# Source code for gigl.src.inference.v1.lib.node_classification_inferencer
from functools import partial
from typing import Callable
import apache_beam as beam
from gigl.common import Uri, UriFactory
from gigl.src.common.graph_builder.abstract_graph_builder import GraphBuilder
from gigl.src.common.types.graph_data import NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.inference.v1.lib.base_inference_blueprint import BaseInferenceBlueprint
from gigl.src.inference.v1.lib.base_inferencer import (
    BaseInferencer,
    SupervisedNodeClassificationBaseInferencer,
)
from gigl.src.inference.v1.lib.transforms.utils import cache_mappings
from gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader import (
    SupervisedNodeClassificationBatch,
)
from snapchat.research.gbml import (
    flattened_graph_metadata_pb2,
    training_samples_schema_pb2,
)
# [docs]
class NodeClassificationInferenceBlueprint(
    BaseInferenceBlueprint[
        training_samples_schema_pb2.SupervisedNodeClassificationSample,
        SupervisedNodeClassificationBatch,
    ]
):
    """
    Concrete NodeClassificationInferenceBlueprint class that implements functions in order
    to correctly compute and save inference results for SupervisedNodeClassification tasks.

    Implements Generics:
        RawSampleType = training_samples_schema_pb2.SupervisedNodeClassificationSample
        BatchType = SupervisedNodeClassificationBatch
    """

    def __init__(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        inferencer: BaseInferencer,
        graph_builder: GraphBuilder,
    ) -> None:
        """
        Store the config wrapper and graph builder, then initialize the base blueprint.

        Args:
            gbml_config_pb_wrapper: Wrapper around the GBML config; the other methods of
                this class read the flattened graph metadata, task metadata, graph
                metadata and preprocessed metadata from it.
            inferencer: Inferencer forwarded to the base blueprint. Must be an instance
                of SupervisedNodeClassificationBaseInferencer.
            graph_builder: Builder bound into the batch generator function returned by
                ``get_batch_generator_fn``.

        Raises:
            AssertionError: If ``inferencer`` is not a
                SupervisedNodeClassificationBaseInferencer.
        """
        # TODO (tzhao-sc): change these to args of functions s.t. we only initialize a property
        #              at the top level components
        self.__builder = graph_builder
        self.__gbml_config_pb_wrapper = gbml_config_pb_wrapper
        # NOTE(review): cache_mappings is defined in transforms.utils; presumably it
        # pre-populates lookups derived from the config -- confirm against that module.
        cache_mappings(gbml_config_pb_wrapper=self.__gbml_config_pb_wrapper)
        assert isinstance(
            inferencer, SupervisedNodeClassificationBaseInferencer
        ), f"Expected a SupervisedNodeClassificationBaseInferencer, got {type(inferencer).__name__}"
        super().__init__(inferencer=inferencer)

    def get_inference_data_tf_record_uri_prefixes(self) -> dict[NodeType, list[Uri]]:
        """
        Return the TFRecord URI prefixes holding the samples to run inference on.

        Returns:
            A single-entry dict keyed by the task's only supervision node type, whose
            value is ``[labeled prefix URI, unlabeled prefix URI]`` taken from the
            flattened graph metadata output.

        Raises:
            AssertionError: If the flattened graph output metadata is not a
                SupervisedNodeClassificationOutput, or the task metadata is not a
                node based task.
            NotImplementedError: If the task does not declare exactly one supervision
                node type.
        """
        gbml_config_pb_wrapper = self.__gbml_config_pb_wrapper

        output_metadata = (
            gbml_config_pb_wrapper.flattened_graph_metadata_pb_wrapper.output_metadata
        )
        assert isinstance(
            output_metadata,
            flattened_graph_metadata_pb2.SupervisedNodeClassificationOutput,
        ), (
            "Expected flattened graph output metadata to be "
            f"SupervisedNodeClassificationOutput, got {type(output_metadata).__name__}"
        )

        task_metadata_pb_wrapper = gbml_config_pb_wrapper.task_metadata_pb_wrapper
        task_metadata_type = task_metadata_pb_wrapper.task_metadata_type
        # Report the type that was actually found, not the one that was expected.
        assert (
            task_metadata_type == TaskMetadataType.NODE_BASED_TASK
        ), f"Expected task metadata to be node based task, got {task_metadata_type}"

        inferencer_node_types = (
            task_metadata_pb_wrapper.task_metadata_pb.node_based_task_metadata.supervision_node_types
        )
        if len(inferencer_node_types) != 1:
            raise NotImplementedError(
                f"Supervised node classification task expects one output node type, found {len(inferencer_node_types)} node types: {inferencer_node_types}"
            )

        # Both the labeled and the unlabeled sample prefixes are returned for the node type.
        return {
            NodeType(inferencer_node_types[0]): [
                UriFactory.create_uri(output_metadata.labeled_tfrecord_uri_prefix),
                UriFactory.create_uri(output_metadata.unlabeled_tfrecord_uri_prefix),
            ]
        }

    def get_tf_record_coder(self) -> beam.coders.ProtoCoder:
        """
        Return the Beam coder for SupervisedNodeClassificationSample protos.

        Returns:
            A ``beam.coders.ProtoCoder`` constructed for
            ``training_samples_schema_pb2.SupervisedNodeClassificationSample``.
        """
        return beam.coders.ProtoCoder(
            proto_message_type=training_samples_schema_pb2.SupervisedNodeClassificationSample
        )

    def get_batch_generator_fn(self) -> Callable:
        """
        Return the function used to turn raw samples into a batch.

        Returns:
            ``SupervisedNodeClassificationBatch.process_raw_pyg_samples_and_collate_fn``
            with the graph builder, graph metadata wrapper and preprocessed metadata
            wrapper already bound as keyword arguments.
        """
        gbml_config_pb_wrapper = self.__gbml_config_pb_wrapper
        return partial(
            SupervisedNodeClassificationBatch.process_raw_pyg_samples_and_collate_fn,
            builder=self.__builder,
            graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
            preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
        )