gigl.src.inference.v1.lib.node_classification_inferencer

Classes

NodeClassificationInferenceBlueprint

Concrete NodeClassificationInferenceBlueprint class that implements functions in order to correctly compute and save inference results for SupervisedNodeClassification tasks.

Module Contents

class gigl.src.inference.v1.lib.node_classification_inferencer.NodeClassificationInferenceBlueprint(gbml_config_pb_wrapper, inferencer, graph_builder)

Bases: gigl.src.inference.v1.lib.base_inference_blueprint.BaseInferenceBlueprint[snapchat.research.gbml.training_samples_schema_pb2.SupervisedNodeClassificationSample, gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader.SupervisedNodeClassificationBatch]

Concrete NodeClassificationInferenceBlueprint class that implements functions in order to correctly compute and save inference results for SupervisedNodeClassification tasks.

Implements Generics:

RawSampleType = training_samples_schema_pb2.SupervisedNodeClassificationSample

BatchType = SupervisedNodeClassificationBatch
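
The way a concrete blueprint binds the two generic parameters of its base class can be sketched with `typing.Generic`. This is a minimal illustration, not GiGL's implementation: `ToySample`, `ToyBatch`, `ToyBaseBlueprint`, `ToyNodeClassificationBlueprint`, and the `collate` method are all hypothetical stand-ins for the real sample, batch, and blueprint types.

```python
from dataclasses import dataclass
from typing import Generic, List, TypeVar

RawSampleT = TypeVar("RawSampleT")
BatchT = TypeVar("BatchT")


@dataclass
class ToySample:
    """Hypothetical stand-in for SupervisedNodeClassificationSample."""
    node_id: int
    label: int


@dataclass
class ToyBatch:
    """Hypothetical stand-in for SupervisedNodeClassificationBatch."""
    node_ids: List[int]
    labels: List[int]


class ToyBaseBlueprint(Generic[RawSampleT, BatchT]):
    """Hypothetical base class parameterized by a sample type and a batch type."""

    def collate(self, samples: List[RawSampleT]) -> BatchT:
        raise NotImplementedError


class ToyNodeClassificationBlueprint(ToyBaseBlueprint[ToySample, ToyBatch]):
    """Binds both type parameters, mirroring RawSampleType and BatchType above."""

    def collate(self, samples: List[ToySample]) -> ToyBatch:
        # Gather per-sample fields into batch-level lists.
        return ToyBatch(
            node_ids=[s.node_id for s in samples],
            labels=[s.label for s in samples],
        )


batch = ToyNodeClassificationBlueprint().collate([ToySample(1, 0), ToySample(2, 1)])
```

Binding the parameters in the subclass declaration lets a type checker verify that every overridden method consumes and produces the task-specific types.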

Parameters:

gbml_config_pb_wrapper

inferencer

graph_builder

get_batch_generator_fn()
Returns:

The function specific to the batch type needed for the inference task at hand.

Return type:

Callable
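
Because the method returns a `Callable`, the caller receives a function and decides when to invoke it. The sketch below shows only that pattern. `ToyBlueprint` and `toy_batch_generator` are hypothetical, and the arguments the real batch generator function accepts are not documented here, so the `samples` and `batch_size` parameters are assumptions.

```python
from typing import Callable, Iterator, List


def toy_batch_generator(samples: List[int], batch_size: int) -> Iterator[List[int]]:
    """Hypothetical batch function: yields consecutive chunks of at most batch_size items."""
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]


class ToyBlueprint:
    """Hypothetical blueprint; only the 'return a callable' pattern matters."""

    def get_batch_generator_fn(self) -> Callable[..., Iterator[List[int]]]:
        # Return the function object itself, not the result of calling it.
        return toy_batch_generator


batch_fn = ToyBlueprint().get_batch_generator_fn()
batches = list(batch_fn([1, 2, 3, 4, 5], batch_size=2))
```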

get_inference_data_tf_record_uri_prefixes()
Returns:

Dictionary mapping each node type to the list of URI prefixes under which the TFRecord files used for inference are found.

Return type:

Dict[NodeType, List[Uri]]
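
The shape of the returned mapping can be illustrated as follows. This is a sketch under assumptions: `NodeType` and `Uri` are modeled as simple `NewType` wrappers around `str`, which may differ from GiGL's actual types, and the node type names, bucket paths, and `flatten_prefixes` helper are hypothetical.

```python
from typing import Dict, List, NewType, Tuple

# Stand-in types for illustration; GiGL's own NodeType and Uri may differ.
NodeType = NewType("NodeType", str)
Uri = NewType("Uri", str)

# Hypothetical return value: one or more URI prefixes per node type.
prefixes: Dict[NodeType, List[Uri]] = {
    NodeType("user"): [Uri("gs://example-bucket/inference/user/part-a/")],
    NodeType("item"): [
        Uri("gs://example-bucket/inference/item/part-a/"),
        Uri("gs://example-bucket/inference/item/part-b/"),
    ],
}


def flatten_prefixes(mapping: Dict[NodeType, List[Uri]]) -> List[Tuple[NodeType, Uri]]:
    """Expand the mapping into (node_type, uri_prefix) pairs, one per prefix."""
    return [(node_type, uri) for node_type, uris in mapping.items() for uri in uris]


pairs = flatten_prefixes(prefixes)
```

A consumer would typically iterate over such pairs and read every TFRecord file that matches each prefix for the corresponding node type.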

get_tf_record_coder()
Returns:

The coder used to parse the TFRecords into raw data samples of type RawSampleType.

Return type:

beam.coders.ProtoCoder
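
A Beam coder converts between in-memory objects and bytes through `encode` and `decode` methods, and `ProtoCoder` does this for a protobuf message type. The sketch below illustrates that role without depending on Beam or protobuf: `ToyJsonCoder` and `ToySample` are hypothetical, and JSON serialization stands in for protobuf serialization.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ToySample:
    """Hypothetical stand-in for a raw sample message."""
    node_id: int
    label: int


class ToyJsonCoder:
    """Hypothetical stand-in for a Beam coder; JSON replaces protobuf here."""

    def encode(self, value: ToySample) -> bytes:
        # Serialize the sample to bytes, as a coder does when writing records.
        return json.dumps(asdict(value)).encode("utf-8")

    def decode(self, encoded: bytes) -> ToySample:
        # Parse bytes back into a typed sample, as a coder does when reading records.
        return ToySample(**json.loads(encoded.decode("utf-8")))


coder = ToyJsonCoder()
record = coder.encode(ToySample(node_id=7, label=1))
sample = coder.decode(record)
```

In the real pipeline, the coder returned by `get_tf_record_coder()` plays the `decode` role, turning each serialized record into a RawSampleType instance.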