gigl.experimental.knowledge_graph_embedding.common.torchrec.batch
Classes
BatchBase | This class extends pytorch/torchrec to be reusable for any batch.
 | Makes it easy to create a Batch with some generic dataclass.
Module Contents
- class gigl.experimental.knowledge_graph_embedding.common.torchrec.batch.BatchBase
Bases: torchrec.streamable.Pipelineable, abc.ABC
This class extends pytorch/torchrec to be reusable for any batch.
This enables use with torchrec tools such as pipelined training, which overlaps the host-to-device transfer of loaded batches (copy to GPU), inter-device communication, and the forward and backward passes.
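A Pipelineable batch has to be able to move itself to a device and record a CUDA stream on its tensors. For a dataclass-style batch, both operations can forward field by field. The sketch below shows that pattern under stated assumptions. It does not import torchrec and is not the BatchBase implementation. The class ToyTripleBatch and its head_ids, relation_ids, and tail_ids fields are hypothetical. Any field that exposes .to() or .record_stream() is treated as a tensor, and every other field is carried over unchanged.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyTripleBatch:
    """Hypothetical dataclass batch (not part of gigl or torchrec).

    Fields are expected to be torch tensors, but any object exposing
    .to() / .record_stream() is accepted (duck typing).
    """

    head_ids: Any
    relation_ids: Any
    tail_ids: Any

    def to(self, device: Any, non_blocking: bool = False) -> "ToyTripleBatch":
        # Move every tensor-like field; non-tensor fields are copied as-is.
        moved = {}
        for f in dataclasses.fields(self):
            value = getattr(self, f.name)
            if hasattr(value, "to"):
                value = value.to(device, non_blocking=non_blocking)
            moved[f.name] = value
        # A new batch object is returned, so callers must rebind:
        # batch = batch.to(device)
        return dataclasses.replace(self, **moved)

    def record_stream(self, stream: Any) -> None:
        # Forward to each tensor-like field (see torch.Tensor.record_stream).
        for f in dataclasses.fields(self):
            value = getattr(self, f.name)
            if hasattr(value, "record_stream"):
                value.record_stream(stream)
```

Duck typing keeps the sketch free of dependencies. In real code, a subclass would hold torch.Tensor fields and could check isinstance(value, torch.Tensor) instead.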
- record_stream(stream)
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- Parameters:
stream (torch.cuda.streams.Stream)
- Return type:
None
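record_stream() matters when a batch is copied to the GPU on a side stream and then consumed on the current stream, which is what pipelined training does. The following is a minimal sketch of that hand-off. The helper name wait_for_batch is hypothetical. With PyTorch, current_stream would typically be torch.cuda.current_stream(), and copy_stream would be the torch.cuda.Stream on which batch.to(device, non_blocking=True) was issued inside a with torch.cuda.stream(copy_stream): block. The arguments are duck-typed so the ordering logic can be exercised without a GPU.

```python
from typing import Any


def wait_for_batch(batch: Any, current_stream: Any, copy_stream: Any) -> Any:
    """Hypothetical helper: make a batch that was copied to the device on
    copy_stream safe to consume on current_stream."""
    # Work queued on current_stream after this call will not start until
    # the pending copy on copy_stream has finished.
    current_stream.wait_stream(copy_stream)
    # The batch memory was allocated on copy_stream. Recording current_stream
    # ensures that, once the batch is freed, its memory is not reused until
    # the work already queued on current_stream has completed.
    batch.record_stream(current_stream)
    return batch
```

The order follows the usual pattern: make the consumer stream wait first, then record it on the batch before any kernels that read the batch are launched.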
- to(device, non_blocking=False)
Be aware that, according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to() might return self or a copy of self. Always use to() with the assignment operator, for example, batch = batch.to(new_device).
- Parameters:
device (torch.device)
non_blocking (bool)
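The self-or-copy behaviour explains why the result of to() must be assigned. The sketch below assumes a hypothetical frozen ToyBatch whose device is a plain string rather than a torch.device, so it runs without PyTorch. It mirrors torch.Tensor.to: self is returned when nothing would change, and otherwise a new object is returned while the original is left untouched.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyBatch:
    """Hypothetical stand-in for a BatchBase subclass (not a gigl class)."""

    device: str
    node_ids: tuple

    def to(self, device: str, non_blocking: bool = False) -> "ToyBatch":
        # non_blocking is accepted for signature parity and ignored here.
        if device == self.device:
            return self  # Already on the target device: no copy.
        return replace(self, device=device)  # New object; self is unchanged.


batch = ToyBatch(device="cpu", node_ids=(1, 2, 3))

batch.to("cuda:0")  # Wrong: the moved copy is discarded.
assert batch.device == "cpu"

batch = batch.to("cuda:0")  # Right: rebind the name to the result.
assert batch.device == "cuda:0"
assert batch.to("cuda:0") is batch  # Same device, so self is returned.
```

Rebinding works in both cases, whether to() returned self or a copy, which is why it is the safe idiom.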