gigl.src.inference.v1.lib.base_inference_blueprint
Attributes
Classes
BaseInferenceBlueprint: Abstract Base Class that needs to be implemented for inference dataflow pipelines.
Module Contents
- class gigl.src.inference.v1.lib.base_inference_blueprint.BaseInferenceBlueprint(inferencer)
Bases: abc.ABC, Generic[RawSampleType, BatchType]
Abstract Base Class that needs to be implemented for inference dataflow pipelines to correctly compute and save inference results for GBML tasks, such as Supervised Node Classification, Node Anchor-Based Link Prediction, Supervised Link-Based Task Split, etc.
Implements Generics:
- RawSampleType: The raw sample that will be parsed from get_tf_record_coder.
- BatchType: The batch type needed for model inference (forward pass) for the specific task at hand (e.g. RootedNodeNeighborhoodBatch).
- Parameters:
inferencer (gigl.src.inference.v1.lib.base_inferencer.BaseInferencer)
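To make the generic parameterization concrete, here is a minimal sketch of how a concrete blueprint might bind RawSampleType and BatchType. It deliberately uses a stand-in base class (`BlueprintSketch`) that mirrors only the two abstract methods listed in this module, so it runs without GiGL installed. `MyRawSample`, `MyBatch`, `MyBlueprint`, the `gs://` prefix, and the collate-style signature of the returned batch generator are all hypothetical; the real class takes a BaseInferencer and may require more than is shown here.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, Generic, List, TypeVar

RawSampleType = TypeVar("RawSampleType")
BatchType = TypeVar("BatchType")


class BlueprintSketch(ABC, Generic[RawSampleType, BatchType]):
    """Stand-in base class; NOT the real BaseInferenceBlueprint."""

    def __init__(self, inferencer: object) -> None:
        self.inferencer = inferencer

    @abstractmethod
    def get_batch_generator_fn(self) -> Callable:
        ...

    @abstractmethod
    def get_inference_data_tf_record_uri_prefixes(self) -> Dict[str, List[str]]:
        ...


@dataclass
class MyRawSample:  # hypothetical raw sample type
    node_id: int


@dataclass
class MyBatch:  # hypothetical batch type consumed by the model
    node_ids: List[int]


class MyBlueprint(BlueprintSketch[MyRawSample, MyBatch]):
    def get_batch_generator_fn(self) -> Callable[[List[MyRawSample]], MyBatch]:
        # Assumed shape: collate a list of raw samples into one batch.
        return lambda samples: MyBatch(node_ids=[s.node_id for s in samples])

    def get_inference_data_tf_record_uri_prefixes(self) -> Dict[str, List[str]]:
        # Node type -> URI prefixes; plain strings stand in for NodeType and Uri.
        return {"user": ["gs://hypothetical-bucket/inference/user/"]}


blueprint = MyBlueprint(inferencer=None)
batch = blueprint.get_batch_generator_fn()([MyRawSample(1), MyRawSample(2)])
```

Because the base is abstract, instantiating it directly (or a subclass that leaves an abstract method unimplemented) raises TypeError, which is what forces each task to supply its own batch generator and data locations.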
- abstract get_batch_generator_fn()
- Returns:
The function specific to the batch type needed for the inference task at hand.
- Return type:
Callable
- static get_emb_table_schema(should_run_unenumeration=False)
Returns the schema for the BQ table that will house embeddings.
Returns: InferenceOutputBigqueryTableSchema: Instance containing the schema and registered node field. See: https://beam.apache.org/documentation/io/built-in/google-bigquery/#creating-a-table-schema
Example schema:
'fields': [
    {'name': 'source', 'type': 'STRING', 'mode': 'NULLABLE'},
    {'name': 'quote', 'type': 'STRING', 'mode': 'REQUIRED'}
]
- Parameters:
should_run_unenumeration (bool)
- Return type:
gigl.src.inference.v1.lib.inference_output_schema.InferenceOutputBigqueryTableSchema
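The example schema above uses the dictionary form of a BigQuery table schema accepted by Beam: a `fields` list whose entries each carry `name`, `type`, and `mode`. The sketch below shows one way to assemble such a dictionary. The helpers `make_field` and `make_table_schema` and the `node_id` / `emb` columns are hypothetical illustrations, not the fields GiGL actually registers; the real schema is whatever get_emb_table_schema returns.

```python
from typing import Dict, List

# Column modes defined by BigQuery.
VALID_MODES = {"NULLABLE", "REQUIRED", "REPEATED"}


def make_field(name: str, type_: str, mode: str = "NULLABLE") -> Dict[str, str]:
    """Build one schema field entry (hypothetical helper)."""
    if mode not in VALID_MODES:
        raise ValueError(f"Unsupported mode {mode!r}; expected one of {sorted(VALID_MODES)}")
    return {"name": name, "type": type_, "mode": mode}


def make_table_schema(fields: List[Dict[str, str]]) -> Dict[str, List[Dict[str, str]]]:
    """Wrap field entries in the {'fields': [...]} layout (hypothetical helper)."""
    return {"fields": fields}


# Hypothetical embeddings table: one id column plus a repeated float column.
example_schema = make_table_schema(
    [
        make_field("node_id", "INTEGER", "REQUIRED"),
        make_field("emb", "FLOAT", "REPEATED"),
    ]
)
```

A REPEATED FLOAT column is one common way to store a vector in BigQuery, which is why it is used for the illustrative `emb` field here.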
- abstract get_inference_data_tf_record_uri_prefixes()
- Returns:
Dictionary mapping each node type to the list of URI prefixes under which the TFRecord files used for inference are stored.
- Return type:
Dict[NodeType, List[Uri]]
- get_inferer()
Returns a function that takes a BatchType instance as input and yields TaggedOutputs tagged with either PREDICTION_TAGGED_OUTPUT_KEY or EMBEDDING_TAGGED_OUTPUT_KEY. Each value is a Dict that can be written directly to BQ, following the schema defined in get_emb_table_schema for outputs tagged EMBEDDING_TAGGED_OUTPUT_KEY and in get_pred_table_schema for outputs tagged PREDICTION_TAGGED_OUTPUT_KEY.
For example, the following will be mapped to the predictions table:
pvalue.TaggedOutput(
    PREDICTION_TAGGED_OUTPUT_KEY,
    {
        'source': 'Mahatma Gandhi',
        'quote': 'My life is my message.'
    }
)
Note that the output follows the schema presented in get_pred_table_schema.
- Return type:
Callable[[BatchType], Iterable[apache_beam.pvalue.TaggedOutput]]
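The tagging contract described for get_inferer can be sketched as a generator that emits one embedding row and one prediction row per node. This is an illustration under stated assumptions, not GiGL's implementation: `make_inferer`, `embed_fn`, `predict_fn`, the list-of-ids batch, and the row field names are hypothetical, and the value `"predictions"` for PREDICTION_TAGGED_OUTPUT_KEY is assumed because only EMBEDDING_TAGGED_OUTPUT_KEY's value is documented in this module. If Apache Beam is not installed, a small stand-in with the same `tag` and `value` attributes is used so the sketch still runs.

```python
from typing import Callable, Iterator, List

try:
    from apache_beam.pvalue import TaggedOutput
except ImportError:
    class TaggedOutput:  # stand-in exposing the same tag/value attributes
        def __init__(self, tag: str, value: object) -> None:
            self.tag = tag
            self.value = value

EMBEDDING_TAGGED_OUTPUT_KEY = "embeddings"  # value documented in this module
PREDICTION_TAGGED_OUTPUT_KEY = "predictions"  # ASSUMED value, not documented here


def make_inferer(
    embed_fn: Callable[[int], List[float]],
    predict_fn: Callable[[int], int],
) -> Callable[[List[int]], Iterator[TaggedOutput]]:
    """Return a function mapping a batch of node ids to tagged BQ rows."""

    def inferer(batch: List[int]) -> Iterator[TaggedOutput]:
        for node_id in batch:
            # Row destined for the embeddings table.
            yield TaggedOutput(
                EMBEDDING_TAGGED_OUTPUT_KEY,
                {"node_id": node_id, "emb": embed_fn(node_id)},
            )
            # Row destined for the predictions table.
            yield TaggedOutput(
                PREDICTION_TAGGED_OUTPUT_KEY,
                {"node_id": node_id, "pred": predict_fn(node_id)},
            )

    return inferer


# Toy model functions, for illustration only.
inferer = make_inferer(
    embed_fn=lambda n: [float(n), n / 2],
    predict_fn=lambda n: n % 2,
)
outputs = list(inferer([1, 2]))
```

In a Beam pipeline, a function like this would typically run inside a ParDo with multiple outputs, so that rows carrying each tag can be routed to their own BigQuery table.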
- static get_pred_table_schema(should_run_unenumeration=False)
Returns the schema for the BQ table that will house predictions.
Returns: InferenceOutputBigqueryTableSchema: Instance containing the schema and registered node field. See: https://beam.apache.org/documentation/io/built-in/google-bigquery/#creating-a-table-schema
Example schema:
'fields': [
    {'name': 'source', 'type': 'STRING', 'mode': 'NULLABLE'},
    {'name': 'quote', 'type': 'STRING', 'mode': 'REQUIRED'}
]
- Parameters:
should_run_unenumeration (bool)
- Return type:
gigl.src.inference.v1.lib.inference_output_schema.InferenceOutputBigqueryTableSchema
- gigl.src.inference.v1.lib.base_inference_blueprint.EMBEDDING_TAGGED_OUTPUT_KEY = 'embeddings'