# Source code for gigl.src.inference.v1.lib.inference_blueprint_factory
from gigl.src.common.graph_builder.abstract_graph_builder import GraphBuilder
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.inference.v1.lib.base_inference_blueprint import BaseInferenceBlueprint
from gigl.src.inference.v1.lib.base_inferencer import (
    BaseInferencer,
    SupervisedNodeClassificationBaseInferencer,
)
from gigl.src.inference.v1.lib.node_anchor_based_link_prediction_inferencer import (
    NodeAnchorBasedLinkPredictionInferenceBlueprint,
)
from gigl.src.inference.v1.lib.node_classification_inferencer import (
    NodeClassificationInferenceBlueprint,
)
[docs]
class InferenceBlueprintFactory:
    """Factory that picks the inference blueprint matching a GBML task type."""

    @classmethod
    def get_inference_blueprint(
        cls,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        inferencer_instance: BaseInferencer,
        graph_builder: GraphBuilder,
    ) -> BaseInferenceBlueprint:
        """Build the inference blueprint for the task type declared in the config.

        The task type is read from
        ``gbml_config_pb_wrapper.task_metadata_pb_wrapper.task_metadata_type``.

        Args:
            gbml_config_pb_wrapper: Config wrapper; forwarded unchanged to the
                selected blueprint.
            inferencer_instance: Inferencer forwarded to the selected blueprint.
                For ``NODE_BASED_TASK`` it must be a
                ``SupervisedNodeClassificationBaseInferencer``.
            graph_builder: Graph builder forwarded to the selected blueprint.

        Returns:
            A ``NodeClassificationInferenceBlueprint`` for ``NODE_BASED_TASK``,
            or a ``NodeAnchorBasedLinkPredictionInferenceBlueprint`` for
            ``NODE_ANCHOR_BASED_LINK_PREDICTION_TASK``.

        Raises:
            TypeError: If the task type is not one of the two supported types,
                or if a node-based task is paired with an inferencer that is
                not a ``SupervisedNodeClassificationBaseInferencer``.
        """
        task_metadata_type = (
            gbml_config_pb_wrapper.task_metadata_pb_wrapper.task_metadata_type
        )

        if task_metadata_type == TaskMetadataType.NODE_BASED_TASK:
            # Explicit raise rather than `assert`: assertions are removed under
            # `python -O`, which would let a mismatched inferencer through.
            # The isinstance guard also narrows the type for static checkers.
            if not isinstance(
                inferencer_instance, SupervisedNodeClassificationBaseInferencer
            ):
                raise TypeError(
                    "Node-based tasks require a "
                    "SupervisedNodeClassificationBaseInferencer, got "
                    f"{type(inferencer_instance).__name__}"
                )
            return NodeClassificationInferenceBlueprint(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper,
                inferencer=inferencer_instance,
                graph_builder=graph_builder,
            )

        if (
            task_metadata_type
            == TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK
        ):
            # NOTE(review): no inferencer type check is performed on this
            # branch, matching the original behavior — confirm whether one is
            # wanted for parity with the node-based branch.
            return NodeAnchorBasedLinkPredictionInferenceBlueprint(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper,
                inferencer=inferencer_instance,
                graph_builder=graph_builder,
            )

        raise TypeError(f"GBML task type not supported: {task_metadata_type}")