# Source code for gigl.src.inference.v1.lib.node_classification_inferencer

from functools import partial
from typing import Callable, Dict, List

import apache_beam as beam

from gigl.common import Uri, UriFactory
from gigl.src.common.graph_builder.abstract_graph_builder import GraphBuilder
from gigl.src.common.types.graph_data import NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.inference.v1.lib.base_inference_blueprint import BaseInferenceBlueprint
from gigl.src.inference.v1.lib.base_inferencer import (
    BaseInferencer,
    SupervisedNodeClassificationBaseInferencer,
)
from gigl.src.inference.v1.lib.transforms.utils import cache_mappings
from gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader import (
    SupervisedNodeClassificationBatch,
)
from snapchat.research.gbml import (
    flattened_graph_metadata_pb2,
    training_samples_schema_pb2,
)


class NodeClassificationInferenceBlueprint(
    BaseInferenceBlueprint[
        training_samples_schema_pb2.SupervisedNodeClassificationSample,
        SupervisedNodeClassificationBatch,
    ]
):
    """
    Concrete NodeClassificationInferenceBlueprint class that implements functions in order
    to correctly compute and save inference results for SupervisedNodeClassification tasks.

    Implements Generics:
        RawSampleType = training_samples_schema_pb2.SupervisedNodeClassificationSample
        BatchType = SupervisedNodeClassificationBatch
    """

    def __init__(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        inferencer: BaseInferencer,
        graph_builder: GraphBuilder,
    ) -> None:
        """
        Store the config wrapper and graph builder, then initialize the base blueprint.

        Args:
            gbml_config_pb_wrapper: Config wrapper that the other methods read the
                flattened-graph, task, graph and preprocessed metadata from.
            inferencer: Inferencer handed to the base blueprint; must be a
                SupervisedNodeClassificationBaseInferencer.
            graph_builder: Builder bound into the batch generator function.

        Raises:
            AssertionError: If ``inferencer`` is not a
                SupervisedNodeClassificationBaseInferencer.
        """
        # TODO (tzhao-sc): change these to args of functions s.t. we only initialize a property
        # at the top level components
        self.__builder = graph_builder
        self.__gbml_config_pb_wrapper = gbml_config_pb_wrapper
        # NOTE(review): cache_mappings is defined elsewhere; it runs before the
        # inferencer type check, and that ordering is kept as in the original.
        cache_mappings(gbml_config_pb_wrapper=self.__gbml_config_pb_wrapper)

        assert isinstance(
            inferencer, SupervisedNodeClassificationBaseInferencer
        ), f"Expected a SupervisedNodeClassificationBaseInferencer, got {type(inferencer)}"
        super().__init__(inferencer=inferencer)

    def get_inference_data_tf_record_uri_prefixes(self) -> Dict[NodeType, List[Uri]]:
        """
        Build the mapping from the supervision node type to its TFRecord URI prefixes.

        Returns:
            A single-entry dict keyed by the one supervision node type, whose value
            is ``[labeled_prefix, unlabeled_prefix]`` taken from the flattened graph
            output metadata.

        Raises:
            AssertionError: If the flattened graph output metadata is not a
                SupervisedNodeClassificationOutput, or the task metadata type is not
                NODE_BASED_TASK.
            NotImplementedError: If the task does not declare exactly one
                supervision node type.
        """
        flattened_graph_metadata_pb_wrapper = (
            self.__gbml_config_pb_wrapper.flattened_graph_metadata_pb_wrapper
        )
        output_metadata = flattened_graph_metadata_pb_wrapper.output_metadata
        assert isinstance(
            output_metadata,
            flattened_graph_metadata_pb2.SupervisedNodeClassificationOutput,
        ), f"Expected SupervisedNodeClassificationOutput, got {type(output_metadata)}"

        task_metadata_pb_wrapper = (
            self.__gbml_config_pb_wrapper.task_metadata_pb_wrapper
        )
        # Report the type that was actually found; the previous message echoed the
        # expected constant, which made a failure impossible to diagnose.
        assert (
            task_metadata_pb_wrapper.task_metadata_type
            == TaskMetadataType.NODE_BASED_TASK
        ), f"Expected task metadata to be node based task, got {task_metadata_pb_wrapper.task_metadata_type}"

        inferencer_node_types = (
            task_metadata_pb_wrapper.task_metadata_pb.node_based_task_metadata.supervision_node_types
        )
        if len(inferencer_node_types) != 1:
            raise NotImplementedError(
                f"Supervised node classification task expects one output node type, found {len(inferencer_node_types)} node types: {inferencer_node_types}"
            )

        # Labeled prefix first, then unlabeled, as in the original ordering.
        return {
            NodeType(inferencer_node_types[0]): [
                UriFactory.create_uri(output_metadata.labeled_tfrecord_uri_prefix),
                UriFactory.create_uri(output_metadata.unlabeled_tfrecord_uri_prefix),
            ]
        }

    def get_tf_record_coder(self) -> beam.coders.ProtoCoder:
        """
        Return a Beam ProtoCoder for SupervisedNodeClassificationSample messages.
        """
        coder = beam.coders.ProtoCoder(
            proto_message_type=training_samples_schema_pb2.SupervisedNodeClassificationSample
        )
        return coder

    def get_batch_generator_fn(self) -> Callable:
        """
        Return the batch collate function with this blueprint's context pre-bound.

        The result is ``SupervisedNodeClassificationBatch.process_raw_pyg_samples_and_collate_fn``
        partially applied with the graph builder and the graph / preprocessed
        metadata wrappers from the stored config.
        """
        return partial(
            SupervisedNodeClassificationBatch.process_raw_pyg_samples_and_collate_fn,
            builder=self.__builder,
            graph_metadata_pb_wrapper=self.__gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
            preprocessed_metadata_pb_wrapper=self.__gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
        )