from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, Iterable, List, Optional, TypeVar
import apache_beam as beam
from apache_beam import pvalue
from gigl.common import Uri
from gigl.src.common.types.graph_data import NodeType
# RawSampleType (defined below) is the raw data format read from TFRecord files, i.e. a proto class.
from gigl.src.inference.v1.lib.base_inferencer import BaseInferencer
from gigl.src.inference.v1.lib.inference_output_schema import (
DEFAULT_EMBEDDING_FIELD,
DEFAULT_EMBEDDINGS_TABLE_SCHEMA,
DEFAULT_NODE_ID_FIELD,
DEFAULT_PREDICTION_FIELD,
DEFAULT_PREDICTIONS_TABLE_SCHEMA,
UNENUMERATED_EMBEDDINGS_TABLE_SCHEMA,
UNENUMERATED_PREDICTIONS_TABLE_SCHEMA,
InferenceOutputBigqueryTableSchema,
)
# Type variable for a single raw sample parsed from TFRecord files
# (i.e. an instance of a proto class) via get_tf_record_coder.
RawSampleType = TypeVar("RawSampleType")

# Type variable for a batched representation of raw samples, used to make
# model inference (forward passes) more efficient.
BatchType = TypeVar("BatchType")

# Tags used with beam's pvalue.TaggedOutput to route inference results to
# the predictions table vs. the embeddings table.
PREDICTION_TAGGED_OUTPUT_KEY = "predictions"
EMBEDDING_TAGGED_OUTPUT_KEY = "embeddings"
class BaseInferenceBlueprint(
    ABC,
    Generic[
        RawSampleType,
        BatchType,
    ],
):
    """
    Abstract Base Class that needs to be implemented for inference dataflow pipelines
    to correctly compute and save inference results for GBML tasks, such as
    Supervised Node Classification, Node Anchor-Based Link Prediction,
    Supervised Link-Based Task Split, etc.

    Implements Generics:
        - RawSampleType: The raw sample that will be parsed from get_tf_record_coder.
        - BatchType: The batch type needed for model inference (forward pass) for the
          specific task at hand (e.g. RootedNodeNeighborhoodBatch).
    """

    def __init__(self, inferencer: BaseInferencer):
        # Inferencer whose infer_batch() performs the forward pass over each batch.
        self._inferencer = inferencer

    def get_inferer(
        self,
    ) -> Callable[[BatchType], Iterable[pvalue.TaggedOutput]]:
        """
        Returns a function that takes a BatchType object instance as input and yields
        TaggedOutputs with tags of either PREDICTION_TAGGED_OUTPUT_KEY or
        EMBEDDING_TAGGED_OUTPUT_KEY. The value is a Dict that can be directly written
        to BQ following the schemas defined in get_pred_table_schema for outputs with
        tag PREDICTION_TAGGED_OUTPUT_KEY and get_emb_table_schema for outputs with
        tag EMBEDDING_TAGGED_OUTPUT_KEY.

        For example, the following will be mapped to the predictions table:
            pvalue.TaggedOutput(
                PREDICTION_TAGGED_OUTPUT_KEY,
                {
                    'source': 'Mahatma Gandhi', 'quote': 'My life is my message.'
                }
            )
        Note that the output follows the schema presented in get_pred_table_schema.
        """

        def _make_inference(
            batch: BatchType,
        ) -> Iterable[pvalue.TaggedOutput]:
            infer_batch_results = self._inferencer.infer_batch(batch=batch)
            # Hoisted out of the loop: these attribute reads are invariant
            # across all root nodes of the batch.
            predictions = infer_batch_results.predictions
            embeddings = infer_batch_results.embeddings
            # NOTE(review): assumes predictions/embeddings (when present) are
            # indexed in lockstep with batch.root_nodes — confirm with inferencer.
            for i, node in enumerate(batch.root_nodes):  # type: ignore
                if predictions is not None:
                    yield pvalue.TaggedOutput(
                        PREDICTION_TAGGED_OUTPUT_KEY,
                        {
                            DEFAULT_NODE_ID_FIELD: node.id,
                            DEFAULT_PREDICTION_FIELD: predictions[i].tolist(),
                        },
                    )
                if embeddings is not None:
                    yield pvalue.TaggedOutput(
                        EMBEDDING_TAGGED_OUTPUT_KEY,
                        {
                            DEFAULT_NODE_ID_FIELD: node.id,
                            DEFAULT_EMBEDDING_FIELD: embeddings[i].tolist(),
                        },
                    )

        return _make_inference

    @staticmethod
    def get_emb_table_schema(
        should_run_unenumeration: bool = False,
    ) -> InferenceOutputBigqueryTableSchema:
        """
        Returns the schema for the BQ table that will house embeddings.

        Args:
            should_run_unenumeration (bool): If True, return the schema variant for
                pipelines that un-enumerate node ids; otherwise the default schema.

        Returns:
            InferenceOutputBigqueryTableSchema: Instance containing the schema and
            registered node field.
            See: https://beam.apache.org/documentation/io/built-in/google-bigquery/#creating-a-table-schema
            Example schema:
                'fields': [
                    {'name': 'source', 'type': 'STRING', 'mode': 'NULLABLE'},
                    {'name': 'quote', 'type': 'STRING', 'mode': 'REQUIRED'}
                ]
        """
        if should_run_unenumeration:
            return UNENUMERATED_EMBEDDINGS_TABLE_SCHEMA
        return DEFAULT_EMBEDDINGS_TABLE_SCHEMA

    @staticmethod
    def get_pred_table_schema(
        should_run_unenumeration: bool = False,
    ) -> InferenceOutputBigqueryTableSchema:
        """
        Returns the schema for the BQ table that will house predictions.

        Args:
            should_run_unenumeration (bool): If True, return the schema variant for
                pipelines that un-enumerate node ids; otherwise the default schema.

        Returns:
            InferenceOutputBigqueryTableSchema: Instance containing the schema and
            registered node field.
            See: https://beam.apache.org/documentation/io/built-in/google-bigquery/#creating-a-table-schema
            Example schema:
                'fields': [
                    {'name': 'source', 'type': 'STRING', 'mode': 'NULLABLE'},
                    {'name': 'quote', 'type': 'STRING', 'mode': 'REQUIRED'}
                ]
        """
        if should_run_unenumeration:
            return UNENUMERATED_PREDICTIONS_TABLE_SCHEMA
        return DEFAULT_PREDICTIONS_TABLE_SCHEMA

    @abstractmethod
    def get_inference_data_tf_record_uri_prefixes(self) -> Dict[NodeType, List[Uri]]:
        """
        Returns:
            Dict[NodeType, List[Uri]]: Dictionary of node type to the list of uri
            prefixes where to find tf record files that will be used for inference.
        """
        raise NotImplementedError

    @abstractmethod
    def get_tf_record_coder(self) -> beam.coders.ProtoCoder:
        """
        Returns:
            beam.coders.ProtoCoder: The coder used to parse the TFRecords to raw
            data samples of type RawSampleType.
        """
        raise NotImplementedError

    @abstractmethod
    def get_batch_generator_fn(
        self,
    ) -> Callable:
        """
        Returns:
            Callable: The function specific to the batch type needed for the
            inference task at hand.
        """
        raise NotImplementedError