gigl.src.inference.v1.lib.base_inference_blueprint

Attributes

Classes

BaseInferenceBlueprint

Abstract Base Class that needs to be implemented for inference dataflow pipelines

Module Contents

class gigl.src.inference.v1.lib.base_inference_blueprint.BaseInferenceBlueprint(inferencer)

Bases: abc.ABC, Generic[RawSampleType, BatchType]

Abstract Base Class that needs to be implemented for inference dataflow pipelines to correctly compute and save inference results for GBML tasks, such as Supervised Node Classification, Node Anchor-Based Link Prediction, Supervised Link-Based Task Split, etc.

Implements Generics:

- RawSampleType: The raw sample type produced when the coder returned by get_tf_record_coder parses a TFRecord.
- BatchType: The batch type needed for model inference (forward pass) for the specific task at hand (e.g. RootedNodeNeighborhoodBatch).

Parameters:

inferencer (gigl.src.inference.v1.lib.base_inferencer.BaseInferencer)
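
A minimal sketch of how a concrete blueprint might be laid out. To stay self-contained it does not import GiGL: SketchBlueprint is a hypothetical stand-in that mirrors only the three abstract methods documented on this page, and it omits the inferencer constructor argument. The node type name, the URI prefix, and the decoding and collating lambdas are illustrative assumptions. In particular, the real get_tf_record_coder returns a beam.coders.ProtoCoder, not the plain callable used here.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, Iterable, List, TypeVar

RawSampleType = TypeVar("RawSampleType")
BatchType = TypeVar("BatchType")


class SketchBlueprint(ABC, Generic[RawSampleType, BatchType]):
    """Hypothetical stand-in mirroring the abstract methods of BaseInferenceBlueprint."""

    @abstractmethod
    def get_inference_data_tf_record_uri_prefixes(self) -> Dict[str, List[str]]:
        ...

    @abstractmethod
    def get_tf_record_coder(self) -> Callable[[bytes], RawSampleType]:
        ...

    @abstractmethod
    def get_batch_generator_fn(self) -> Callable[[Iterable[RawSampleType]], BatchType]:
        ...


class ToyNodeBlueprint(SketchBlueprint[str, List[str]]):
    """Toy implementation: raw samples are strings, a batch is a list of strings."""

    def get_inference_data_tf_record_uri_prefixes(self) -> Dict[str, List[str]]:
        # Node type -> URI prefixes under which TFRecord files are found (made-up values).
        return {"user": ["gs://example-bucket/inference/user/samples-"]}

    def get_tf_record_coder(self) -> Callable[[bytes], str]:
        # Stand-in for a ProtoCoder: decode raw bytes into a raw sample.
        return lambda record: record.decode("utf-8")

    def get_batch_generator_fn(self) -> Callable[[Iterable[str]], List[str]]:
        # Collate raw samples into the batch type the model consumes.
        return lambda samples: list(samples)


blueprint = ToyNodeBlueprint()
decode = blueprint.get_tf_record_coder()
collate = blueprint.get_batch_generator_fn()
batch = collate(decode(record) for record in [b"n1", b"n2"])
```

Because the three methods are abstract, instantiating the base class directly raises TypeError; only a subclass that implements all of them can be constructed.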

abstract get_batch_generator_fn()
Returns:

The function specific to the batch type needed for the inference task at hand.

Return type:

Callable

static get_emb_table_schema(should_run_unenumeration=False)

Returns the schema for the BQ table that will house embeddings.

Returns: InferenceOutputBigqueryTableSchema: Instance containing the schema and registered node field. See: https://beam.apache.org/documentation/io/built-in/google-bigquery/#creating-a-table-schema

Example schema:

    'fields': [
        {'name': 'source', 'type': 'STRING', 'mode': 'NULLABLE'},
        {'name': 'quote', 'type': 'STRING', 'mode': 'REQUIRED'}
    ]

Parameters:

should_run_unenumeration (bool)

Return type:

gigl.src.inference.v1.lib.inference_output_schema.InferenceOutputBigqueryTableSchema
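
To make the Beam table-schema format concrete, here is a small sketch that builds a 'fields' dictionary for an embeddings table and checks a row against it. The helper build_emb_schema and the column names node_id and emb are assumptions for illustration, not the fields GiGL actually registers. Only the name/type/mode layout comes from the Beam BigQuery documentation linked above.

```python
from typing import Dict, List


def build_emb_schema(
    node_field: str = "node_id", emb_field: str = "emb"
) -> Dict[str, List[Dict[str, str]]]:
    """Return a schema dict in the 'fields' layout described in the Beam BigQuery docs."""
    return {
        "fields": [
            {"name": node_field, "type": "INTEGER", "mode": "REQUIRED"},
            # A REPEATED FLOAT column models an array of floats.
            {"name": emb_field, "type": "FLOAT", "mode": "REPEATED"},
        ]
    }


schema = build_emb_schema()
row = {"node_id": 7, "emb": [0.1, 0.2, 0.3]}

# Every key in a row should correspond to a column declared in the schema.
declared = {field["name"] for field in schema["fields"]}
row_matches_schema = set(row) == declared
```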

abstract get_inference_data_tf_record_uri_prefixes()
Returns:

Dictionary mapping each node type to the list of URI prefixes under which the TFRecord files used for inference are found.

Return type:

Dict[NodeType, List[Uri]]

get_inferer()

Returns a function that takes a DigestableBatchType object instance as input and yields TaggedOutputs tagged with either PREDICTION_TAGGED_OUTPUT_KEY or EMBEDDING_TAGGED_OUTPUT_KEY. Each output's value is a Dict that can be written directly to BQ, following the schema defined in get_emb_table_schema for outputs tagged EMBEDDING_TAGGED_OUTPUT_KEY and the schema defined in get_pred_table_schema for outputs tagged PREDICTION_TAGGED_OUTPUT_KEY.

For example, the following will be mapped to the predictions table:

    pvalue.TaggedOutput(
        PREDICTION_TAGGED_OUTPUT_KEY,
        {
            'source': 'Mahatma Gandhi',
            'quote': 'My life is my message.'
        }
    )

Note that the output follows the schema presented in get_pred_table_schema.

Return type:

Callable[[BatchType], Iterable[apache_beam.pvalue.TaggedOutput]]
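
The shape of such an inferer function can be sketched without Beam installed. The TaggedOutput named tuple below is a stand-in for apache_beam.pvalue.TaggedOutput, which likewise takes a tag and a value. The function toy_model, the dict-based batch layout, and the field names node_id, emb, and pred are hypothetical and chosen only for illustration.

```python
from typing import Dict, Iterator, List, NamedTuple, Tuple

EMBEDDING_TAGGED_OUTPUT_KEY = "embeddings"
PREDICTION_TAGGED_OUTPUT_KEY = "predictions"


class TaggedOutput(NamedTuple):
    """Stand-in for apache_beam.pvalue.TaggedOutput(tag, value)."""

    tag: str
    value: Dict[str, object]


def toy_model(features: List[float]) -> Tuple[List[float], float]:
    """Hypothetical model: the embedding doubles each feature, the prediction is the mean."""
    return [2.0 * x for x in features], sum(features) / len(features)


def infer(batch: Dict[int, List[float]]) -> Iterator[TaggedOutput]:
    """Yield one embedding row and one prediction row per node in the batch."""
    for node_id, features in batch.items():
        emb, pred = toy_model(features)
        yield TaggedOutput(EMBEDDING_TAGGED_OUTPUT_KEY, {"node_id": node_id, "emb": emb})
        yield TaggedOutput(PREDICTION_TAGGED_OUTPUT_KEY, {"node_id": node_id, "pred": pred})


outputs = list(infer({1: [1.0, 3.0], 2: [0.5, 0.5]}))
```

In a Beam pipeline, a function like this is typically wrapped in a ParDo whose with_outputs(...) call names both tags, so that each tag becomes its own PCollection and can be written to its own table.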

static get_pred_table_schema(should_run_unenumeration=False)

Returns the schema for the BQ table that will house predictions.

Returns: InferenceOutputBigqueryTableSchema: Instance containing the schema and registered node field. See: https://beam.apache.org/documentation/io/built-in/google-bigquery/#creating-a-table-schema

Example schema:

    'fields': [
        {'name': 'source', 'type': 'STRING', 'mode': 'NULLABLE'},
        {'name': 'quote', 'type': 'STRING', 'mode': 'REQUIRED'}
    ]

Parameters:

should_run_unenumeration (bool)

Return type:

gigl.src.inference.v1.lib.inference_output_schema.InferenceOutputBigqueryTableSchema

abstract get_tf_record_coder()
Returns:

The coder used to parse the TFRecords into raw data samples of type RawSampleType.

Return type:

beam.coders.ProtoCoder

gigl.src.inference.v1.lib.base_inference_blueprint.BatchType
gigl.src.inference.v1.lib.base_inference_blueprint.EMBEDDING_TAGGED_OUTPUT_KEY = 'embeddings'
gigl.src.inference.v1.lib.base_inference_blueprint.PREDICTION_TAGGED_OUTPUT_KEY = 'predictions'
gigl.src.inference.v1.lib.base_inference_blueprint.RawSampleType