# Source code for gigl.experimental.knowledge_graph_embedding.common.iterator_utils

import itertools
from typing import Iterable, Iterator


def batched(it: Iterable, n: int) -> Iterator[Iterator]:
    """
    Create batches of up to n elements from an iterable.

    Takes an input iterable and yields sub-iterators, each containing up to n
    elements. This is useful for processing data in chunks or creating batched
    operations for efficient data pipeline processing.

    Batches are lazy: each sub-iterator pulls its elements on demand from the
    single underlying iterator shared by all batches. Fully consume each batch
    before requesting the next one. Elements left unread in a batch are not
    skipped; they become the leading elements of the following batches, which
    then no longer line up on n-element boundaries. In particular,
    materializing the outer generator first (e.g. ``list(batched(...))``) and
    only afterwards reading the batches gives incorrect results.

    Args:
        it (Iterable): The input to batch. Any iterable is accepted (lists,
            ranges, generators, iterators, ...). If an iterator is passed,
            it is advanced in place.
        n (int): Maximum number of elements per batch. Must be >= 1.

    Yields:
        Iterator: Sub-iterators containing up to n elements from the input.
            The last batch may contain fewer than n elements if the input is
            exhausted. No batch is yielded for an empty input.

    Raises:
        AssertionError: If n < 1. This is a generator function, so the error
            is raised when iteration starts (on the first ``next``), not when
            ``batched`` is called.

    Example:
        >>> data = iter([1, 2, 3, 4, 5, 6, 7])
        >>> for batch in batched(data, 3):
        ...     print(list(batch))
        [1, 2, 3]
        [4, 5, 6]
        [7]
        >>> for batch in batched([1, 2, 3, 4, 5], 2):
        ...     print(list(batch))
        [1, 2]
        [3, 4]
        [5]
    """
    if n < 1:
        # Raised explicitly rather than via ``assert`` so the check still runs
        # under ``python -O``; AssertionError is kept for backward compatibility.
        raise AssertionError(f"n must be >= 1, got {n}")
    # Normalize to one iterator. Without this, a re-iterable container (e.g. a
    # list) would give the ``for`` loop and every ``islice`` their own fresh
    # iterator, repeating elements instead of advancing through the input.
    it = iter(it)
    for first in it:
        # ``first`` was already pulled by the loop; the batch is that element
        # followed by up to n - 1 more from the same shared iterator.
        yield itertools.chain((first,), itertools.islice(it, n - 1))