# Source code for gigl.src.inference.v1.lib.transforms.batch_generator
from typing import Callable, Iterable, Union

import apache_beam as beam

from gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader import (
    RootedNodeNeighborhoodBatch,
)
from gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader import (
    SupervisedNodeClassificationBatch,
)
from snapchat.research.gbml import training_samples_schema_pb2
# Raw (pre-batching) sample protos that a batch generator may consume:
# either a rooted-node-neighborhood sample or a supervised
# node-classification sample.
RawBatchType = Union[
    training_samples_schema_pb2.RootedNodeNeighborhood,
    training_samples_schema_pb2.SupervisedNodeClassificationSample,
]
# Batch objects produced from lists of RawBatchType samples and handed to
# inference; one of the two data-loader batch types imported above.
InferenceBatchType = Union[
    SupervisedNodeClassificationBatch,
    RootedNodeNeighborhoodBatch,
]
class BatchProcessorDoFn(beam.DoFn):
    """Beam ``DoFn`` that turns one list of raw samples into one inference batch.

    The conversion itself is delegated to ``batch_generator_fn``; this class
    only adapts that callable to Beam's ``DoFn.process`` interface.
    """

    def __init__(
        self,
        batch_generator_fn: Callable[..., InferenceBatchType],
    ):
        """Store the batch-generation callable.

        Args:
            batch_generator_fn: Callable invoked as
                ``batch_generator_fn(batch=element)`` for every input element;
                it must accept a keyword argument named ``batch``. Its return
                value is emitted unchanged by :meth:`process`.
                NOTE(review): Beam serializes DoFn instances to ship them to
                workers, so this callable presumably needs to be picklable —
                confirm against callers.
        """
        # Run beam.DoFn's own initializer so base-class state is set up before
        # this subclass adds its attributes.
        super().__init__()
        self.batch_generator_fn = batch_generator_fn

    def process(self, element: list[RawBatchType]) -> Iterable[InferenceBatchType]:
        """Convert one list of raw samples into a single inference batch.

        Args:
            element: Raw sample protos that together form one batch.

        Yields:
            Exactly one value: the result of ``batch_generator_fn(batch=element)``.
        """
        yield self.batch_generator_fn(
            batch=element,
        )