Source code for gigl.experimental.knowledge_graph_embedding.common.iterator_utils
import itertools
from typing import Iterator
[docs]
def batched(it: Iterator, n: int):
    """
    Create batches of up to n elements from an iterator.
    Takes an input iterator and yields sub-iterators, each containing up to n elements.
    This is useful for processing data in chunks or creating batched operations for
    efficient data pipeline processing.
    Args:
        it (Iterator): The input iterator to batch.
        n (int): Maximum number of elements per batch. Must be >= 1.
    Yields:
        Iterator: Sub-iterators containing up to n elements from the input iterator.
                 The last batch may contain fewer than n elements if the input
                 iterator is exhausted.
    Raises:
        AssertionError: If n < 1.
    Example:
        >>> data = iter([1, 2, 3, 4, 5, 6, 7])
        >>> for batch in batched(data, 3):
        ...     print(list(batch))
        [1, 2, 3]
        [4, 5, 6]
        [7]
    """
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))