Source code for gigl.src.inference.v1.lib.inference_blueprint_factory

from gigl.src.common.graph_builder.abstract_graph_builder import GraphBuilder
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.inference.v1.lib.base_inference_blueprint import BaseInferenceBlueprint
from gigl.src.inference.v1.lib.base_inferencer import (
    BaseInferencer,
    SupervisedNodeClassificationBaseInferencer,
)
from gigl.src.inference.v1.lib.node_anchor_based_link_prediction_inferencer import (
    NodeAnchorBasedLinkPredictionInferenceBlueprint,
)
from gigl.src.inference.v1.lib.node_classification_inferencer import (
    NodeClassificationInferenceBlueprint,
)


class InferenceBlueprintFactory:
    """Factory that picks an inference blueprint based on the configured task type."""

    @classmethod
    def get_inference_blueprint(
        cls,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        inferencer_instance: BaseInferencer,
        graph_builder: GraphBuilder,
    ) -> BaseInferenceBlueprint:
        """Build the inference blueprint matching the config's task metadata type.

        Args:
            gbml_config_pb_wrapper: Wrapped GBML config; its
                ``task_metadata_pb_wrapper.task_metadata_type`` selects which
                blueprint class is instantiated.
            inferencer_instance: Inferencer handed to the blueprint. For
                ``NODE_BASED_TASK`` it must be a
                ``SupervisedNodeClassificationBaseInferencer``.
            graph_builder: Graph builder handed to the blueprint.

        Returns:
            A ``NodeClassificationInferenceBlueprint`` for ``NODE_BASED_TASK``,
            or a ``NodeAnchorBasedLinkPredictionInferenceBlueprint`` for
            ``NODE_ANCHOR_BASED_LINK_PREDICTION_TASK``.

        Raises:
            TypeError: If the task metadata type is not one of the two
                supported values, or if a node-based task is paired with an
                inferencer that is not a
                ``SupervisedNodeClassificationBaseInferencer``.
        """
        blueprint: BaseInferenceBlueprint
        task_metadata_type = (
            gbml_config_pb_wrapper.task_metadata_pb_wrapper.task_metadata_type
        )

        if task_metadata_type == TaskMetadataType.NODE_BASED_TASK:
            # Explicit check rather than `assert`: asserts are removed under
            # `python -O`, which would let a mismatched inferencer through.
            # The isinstance test also narrows the type for static checkers.
            if not isinstance(
                inferencer_instance, SupervisedNodeClassificationBaseInferencer
            ):
                raise TypeError(
                    "Node-based tasks require a "
                    "SupervisedNodeClassificationBaseInferencer, got "
                    f"{type(inferencer_instance).__name__}"
                )
            blueprint = NodeClassificationInferenceBlueprint(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper,
                inferencer=inferencer_instance,
                graph_builder=graph_builder,
            )
        elif (
            task_metadata_type
            == TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK
        ):
            # NOTE(review): unlike the node-based branch, the inferencer's
            # concrete type is not checked here before it is passed on.
            blueprint = NodeAnchorBasedLinkPredictionInferenceBlueprint(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper,
                inferencer=inferencer_instance,
                graph_builder=graph_builder,
            )
        else:
            raise TypeError(f"GBML task type not supported: {task_metadata_type}")

        return blueprint