Source code for gigl.experimental.knowledge_graph_embedding.common.torchrec.batch

import abc
from dataclasses import dataclass, field, make_dataclass
from typing import Dict

import torch
from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
    """
    This class extends
    https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/utils.py#L28
    to be reusable for any batch. This enables use with certain torchrec tools
    like pipelined training, which overlaps dataloading, device transfer (copy
    to GPU), inter-device communications, and the forward/backward passes.
    """

    @abc.abstractmethod
    def as_dict(self) -> Dict:
        raise NotImplementedError

    def to(self, device: torch.device, non_blocking: bool = False):
        args = {}
        for feature_name, feature_value in self.as_dict().items():
            args[feature_name] = feature_value.to(
                device=device, non_blocking=non_blocking
            )
        return self.__class__(**args)

    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
        for feature_value in self.as_dict().values():
            feature_value.record_stream(stream)

    def pin_memory(self):
        args = {}
        for feature_name, feature_value in self.as_dict().items():
            args[feature_name] = feature_value.pin_memory()
        return self.__class__(**args)

    def __repr__(self) -> str:
        def obj2str(v):
            return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

        return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])

    @property
    def batch_size(self) -> int:
        for tensor in self.as_dict().values():
            if tensor is None:
                continue
            if not isinstance(tensor, torch.Tensor):
                continue
            return tensor.shape[0]
        raise Exception("Could not determine batch size from tensors.")
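
# --- Usage sketch (illustrative, not part of the module) ---
# A minimal concrete subclass, assuming a batch with a single dense feature
# tensor; the names `_ExampleDenseBatch`, `features`, and
# `_example_pipelined_transfer` are hypothetical. Implementing `as_dict` is
# all that is needed for `to`, `pin_memory`, `record_stream`, and `batch_size`
# to work via the base class.
#
# @dataclass
# class _ExampleDenseBatch(BatchBase):
#     features: torch.Tensor
#
#     def as_dict(self) -> Dict:
#         return {"features": self.features}
#
#
# def _example_pipelined_transfer(batch: _ExampleDenseBatch) -> _ExampleDenseBatch:
#     # Pin host memory, then issue an async copy so the transfer can overlap
#     # with compute, mirroring how pipelined training uses these hooks.
#     if torch.cuda.is_available():
#         return batch.pin_memory().to(torch.device("cuda"), non_blocking=True)
#     return batch
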
# TODO(nshah-sc): Consider folding BatchBase into this class.
@dataclass
class DataclassBatch(BatchBase):
    """
    Makes it easy to create a Batch with some generic dataclass.
    """

    @classmethod
    def feature_names(cls):
        return list(cls.__dataclass_fields__.keys())

    def as_dict(self):
        return {
            feature_name: getattr(self, feature_name)
            for feature_name in self.feature_names()
            if hasattr(self, feature_name)
        }

    @staticmethod
    def from_schema(name: str, schema):
        """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
        return make_dataclass(
            cls_name=name,
            fields=[
                (name, torch.Tensor, field(default=None)) for name in schema.names
            ],
            bases=(DataclassBatch,),
        )

    @staticmethod
    def from_fields(name: str, fields: dict):
        return make_dataclass(
            cls_name=name,
            fields=[
                (_name, _type, field(default=None)) for _name, _type in fields.items()
            ],
            bases=(DataclassBatch,),
        )
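
# --- Usage sketch (illustrative, not part of the module) ---
# A minimal sketch of building a batch class at runtime with
# `DataclassBatch.from_fields`; the class name "ExampleBatch", the field names
# "ids"/"labels", and `_example_from_fields` are hypothetical.
#
# def _example_from_fields() -> None:
#     ExampleBatch = DataclassBatch.from_fields(
#         name="ExampleBatch",
#         fields={"ids": torch.Tensor, "labels": torch.Tensor},
#     )
#     batch = ExampleBatch(ids=torch.arange(8), labels=torch.zeros(8))
#     assert batch.batch_size == 8
#     assert set(batch.as_dict()) == {"ids", "labels"}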