Source code for gigl.src.mocking.lib.tfrecord_io
from typing import Iterable, Optional, Sequence, TypeVar
from uuid import uuid4
import tensorflow as tf
from google.protobuf import message
from gigl.common import Uri
from gigl.common.logger import Logger
[docs]
def write_pb_tfrecord_shards_to_uri(
    pb_samples: Sequence[message.Message],
    uri_prefix: Uri,
    filename_prefix: str = "data",
    chunk_size=100,
    sample_type_for_logging: Optional[str] = "",
    raise_exception_if_no_pb_samples: bool = True,
):
    """
    Given a list of protobufs, chunk them and write them out to TFRecord files.
    """
    if raise_exception_if_no_pb_samples:
        assert len(
            pb_samples
        ), f"Found empty list of {sample_type_for_logging} samples to write to TFRecord files."
    def batch(list_of_items: Sequence[T], chunk_size: int) -> Iterable[Sequence[T]]:
        length_of_list = len(list_of_items)
        for idx in range(0, length_of_list, chunk_size):
            yield list_of_items[idx : min(idx + chunk_size, length_of_list)]
    uri_cls = type(uri_prefix)
    for pb_sample_batch in batch(list_of_items=pb_samples, chunk_size=chunk_size):
        with tf.io.TFRecordWriter(
            uri_cls.join(uri_prefix, f"{filename_prefix}-{str(uuid4())}.tfrecord").uri
        ) as writer:
            for sample in pb_sample_batch:
                writer.write(sample.SerializeToString())
    logger.info(
        f"Wrote {len(pb_samples)} {sample_type_for_logging} samples to {uri_prefix}"
    )