Source code for gigl.src.inference.v1.lib.transforms.batch_generator

from typing import Iterable, List, Union

import apache_beam as beam

from gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader import (
    RootedNodeNeighborhoodBatch,
)
from gigl.src.training.v1.lib.data_loaders.supervised_node_classification_data_loader import (
    SupervisedNodeClassificationBatch,
)
from snapchat.research.gbml import training_samples_schema_pb2

# A single raw sample as it arrives in the inference pipeline: one of the
# sample types declared in ``training_samples_schema_pb2``.
# ``BatchProcessorDoFn.process`` receives a list of these.
RawBatchType = Union[
    training_samples_schema_pb2.RootedNodeNeighborhood,
    training_samples_schema_pb2.SupervisedNodeClassificationSample,
]
# The batch object yielded by ``BatchProcessorDoFn.process``: one of the
# batch classes imported from the training data loaders.
InferenceBatchType = Union[
    SupervisedNodeClassificationBatch,
    RootedNodeNeighborhoodBatch,
]
class BatchProcessorDoFn(beam.DoFn):
    """Beam ``DoFn`` that converts a list of raw samples into one batch.

    The conversion itself is delegated to the ``batch_generator_fn``
    supplied at construction time; this class only adapts that function to
    the ``beam.DoFn`` interface.
    """

    def __init__(
        self,
        batch_generator_fn,
    ):
        """Store the function used to build a batch from raw samples.

        Args:
            batch_generator_fn: Callable invoked as
                ``batch_generator_fn(batch=<list of raw samples>)``. Its
                return value is emitted unchanged by ``process``.
        """
        # Let the Beam base classes run their own initialisation before we
        # attach our state.
        super().__init__()
        self.batch_generator_fn = batch_generator_fn

    def process(self, element: List[RawBatchType]) -> Iterable[InferenceBatchType]:
        """Build a batch from ``element`` and emit it.

        Args:
            element: The list of raw samples to turn into a batch.

        Yields:
            Exactly one item: the value returned by
            ``self.batch_generator_fn(batch=element)``.
        """
        yield self.batch_generator_fn(
            batch=element,
        )