gigl.experimental.knowledge_graph_embedding.common.torchrec.batch#

Classes#

BatchBase

This class extends pytorch/torchrec to be reusable for any batch.

DataclassBatch

Makes it easy to create a Batch with some generic dataclass.

Module Contents#

class gigl.experimental.knowledge_graph_embedding.common.torchrec.batch.BatchBase[source]#

Bases: torchrec.streamable.Pipelineable, abc.ABC

This class extends pytorch/torchrec to be reusable for any batch.

This enables use with certain torchrec tools, such as pipelined training, which overlaps data loading, device transfer (copy to GPU), inter-device communications, and the forward/backward passes.
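As a hedged illustration only (the field names, and the assumption that the base class's to(), pin_memory(), and record_stream() helpers operate over the tensors returned by as_dict(), are not stated on this page), a concrete subclass might look like:

import dataclasses
from typing import Dict

import torch

from gigl.experimental.knowledge_graph_embedding.common.torchrec.batch import BatchBase


@dataclasses.dataclass
class EdgeBatch(BatchBase):
    # Hypothetical tensor fields, used only for illustration.
    src_node_ids: torch.Tensor
    dst_node_ids: torch.Tensor

    def as_dict(self) -> Dict[str, torch.Tensor]:
        # Expose every tensor so the base-class helpers can act on the whole batch.
        return {
            "src_node_ids": self.src_node_ids,
            "dst_node_ids": self.dst_node_ids,
        }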

abstract as_dict()[source]#
Return type:

Dict

pin_memory()[source]#
record_stream(stream)[source]#

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

Parameters:

stream (torch.cuda.streams.Stream)

Return type:

None

to(device, non_blocking=False)[source]#

Please be aware that, per https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to may return either self or a copy of self. So remember to rebind the result with the assignment operator, for example batch = batch.to(new_device); see the sketch after the parameter list.

Parameters:
  • device (torch.device)

  • non_blocking (bool)
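A minimal sketch (the device and flag values are illustrative):

device = torch.device("cuda:0")
# `to` may return a copy rather than mutating in place, so rebind the result.
batch = batch.to(device, non_blocking=True)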

property batch_size: int[source]#
Return type:

int

class gigl.experimental.knowledge_graph_embedding.common.torchrec.batch.DataclassBatch[source]#

Bases: BatchBase

Makes it easy to create a Batch with some generic dataclass.

as_dict()[source]#
classmethod feature_names()[source]#
static from_fields(name, fields)[source]#
Parameters:
  • name (str)

  • fields (dict)
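A minimal usage sketch, assuming fields maps field names to their types and that the returned class is a dataclass-backed batch type (the names below are illustrative, not part of the library):

import torch

# Build a new batch class with two tensor fields.
PairBatch = DataclassBatch.from_fields(
    name="PairBatch",
    fields={"src_emb": torch.Tensor, "dst_emb": torch.Tensor},
)

batch = PairBatch(src_emb=torch.randn(8, 16), dst_emb=torch.randn(8, 16))
print(batch.as_dict().keys())  # expected: dict_keys(['src_emb', 'dst_emb'])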

static from_schema(name, schema)[source]#

Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor.

Parameters:
  • name (str)

  • schema
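The type of schema is not stated here; as an assumption for illustration only, the sketch below treats it as a pyarrow.Schema whose column names become torch.Tensor fields on the generated class:

import pyarrow as pa

# Assumption: every schema column that can be represented as a torch.Tensor
# becomes a field on the generated batch class.
EmbeddingBatch = DataclassBatch.from_schema(
    name="EmbeddingBatch",
    schema=pa.schema([("node_ids", pa.int64()), ("labels", pa.float32())]),
)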