Source code for gigl.experimental.knowledge_graph_embedding.common.torchrec.batch
import abc
from dataclasses import dataclass, field, make_dataclass
from typing import Dict
import torch
from torchrec.streamable import Pipelineable
class BatchBase(Pipelineable, abc.ABC):
    """
    This class extends https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/utils.py#L28
    to be reusable for any batch.
    This enables use with certain torchrec tools, such as pipelined training, which
    overlaps dataloading, device transfer (copy to GPU), inter-device communication,
    and forward/backward compute.
    """
    @abc.abstractmethod
    def as_dict(self) -> Dict:
        """Returns the batch's contents as a mapping from feature name to tensor-like value."""
        raise NotImplementedError
    def to(self, device: torch.device, non_blocking: bool = False):
        """Returns a new batch of the same type with every field moved to `device`."""
        args = {}
        for feature_name, feature_value in self.as_dict().items():
            args[feature_name] = feature_value.to(
                device=device, non_blocking=non_blocking
            )
        return self.__class__(**args) 
    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
        """Marks each field's memory as in use on `stream`, so the allocator does not reclaim it prematurely."""
        for feature_value in self.as_dict().values():
            feature_value.record_stream(stream) 
    def pin_memory(self):
        """Returns a new batch of the same type with every field copied into pinned (page-locked) host memory."""
        args = {}
        for feature_name, feature_value in self.as_dict().items():
            args[feature_name] = feature_value.pin_memory()
        return self.__class__(**args) 
    def __repr__(self) -> str:
        def obj2str(v):
            return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
        return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
    @property
    def batch_size(self) -> int:
        """Infers the batch size from the leading dimension of the first tensor-valued field."""
        for tensor in self.as_dict().values():
            if tensor is None:
                continue
            if not isinstance(tensor, torch.Tensor):
                continue
            return tensor.shape[0]
        raise Exception("Could not determine batch size from tensors.") 
 
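# A minimal sketch (not part of the original module) of a hand-written
# BatchBase subclass: implementing as_dict() is the only requirement, and the
# to / pin_memory / record_stream helpers above then work out of the box.
# The field names here are illustrative only.
@dataclass
class ExampleDenseBatch(BatchBase):
    features: torch.Tensor = None
    labels: torch.Tensor = None

    def as_dict(self) -> Dict:
        return {"features": self.features, "labels": self.labels}
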
# TODO(nshah-sc): Consider folding BatchBase into this class.
@dataclass
class DataclassBatch(BatchBase):
    """
    Makes it easy to create a batch from a generic dataclass.
    """
    @classmethod
    def feature_names(cls):
        """Returns the names of all dataclass fields defined on this batch class."""
        return list(cls.__dataclass_fields__.keys())
    def as_dict(self):
        return {
            feature_name: getattr(self, feature_name)
            for feature_name in self.feature_names()
            if hasattr(self, feature_name)
        } 
    @staticmethod
    def from_schema(name: str, schema):
        """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
        return make_dataclass(
            cls_name=name,
            fields=[(name, torch.Tensor, field(default=None)) for name in schema.names],
            bases=(DataclassBatch,),
        ) 
    @staticmethod
    def from_fields(name: str, fields: dict):
        """Creates a custom DataclassBatch subclass from a mapping of field name to type."""
        return make_dataclass(
            cls_name=name,
            fields=[
                (_name, _type, field(default=None)) for _name, _type in fields.items()
            ],
            bases=(DataclassBatch,),
        )