gigl.experimental.knowledge_graph_embedding.lib.data.node_batch

Classes

NodeBatch

A class for representing a batch of nodes in a heterogeneous graph.

Functions

collate_node_batch_from_range(nodes, ...)

Collates a batch of nodes into a NodeBatch.

Module Contents

class gigl.experimental.knowledge_graph_embedding.lib.data.node_batch.NodeBatch

Bases: gigl.experimental.knowledge_graph_embedding.common.torchrec.batch.DataclassBatch

A class for representing a batch of nodes in a heterogeneous graph. All nodes in a batch share the same condensed_node_type, and inference runs in the context of a single condensed edge type.
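Under one plausible reading, which this page does not confirm, `nodes` holds one key per node type in the key-major `keys`/`values`/`lengths` layout of `torchrec.KeyedJaggedTensor`, so "all nodes share the same condensed_node_type" means exactly one key carries values. The dependency-free sketch below checks that property on plain Python lists; `populated_keys` and `is_single_type_batch` are hypothetical helpers, not GiGL functions.

```python
def populated_keys(keys: list[str], lengths: list[int]) -> list[str]:
    """Return the keys that carry at least one value (hypothetical helper).

    Assumes a key-major layout like torchrec.KeyedJaggedTensor: `lengths`
    holds len(keys) * batch_size entries, all of key 0 first, then key 1, ...
    """
    if not keys or len(lengths) % len(keys) != 0:
        raise ValueError("lengths must hold len(keys) * batch_size entries")
    batch_size = len(lengths) // len(keys)
    return [
        key
        for i, key in enumerate(keys)
        if any(lengths[i * batch_size : (i + 1) * batch_size])
    ]


def is_single_type_batch(keys: list[str], lengths: list[int]) -> bool:
    """True when exactly one node-type key is populated (assumed layout)."""
    return len(populated_keys(keys, lengths)) == 1


# Three "user" nodes, no "item" nodes: a valid single-type batch.
ok = is_single_type_batch(["user", "item"], [1, 1, 1, 0, 0, 0])
```

Mixing node types in one batch, e.g. `lengths=[1, 0, 0, 1]` for two nodes, populates both keys and fails the check.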

static build_data_loader(dataset, condensed_node_type, condensed_edge_type, graph_metadata, sampling_config, dataloader_config, pin_memory)
static from_node_tensors(nodes, condensed_node_type, condensed_edge_type, condensed_node_type_to_node_type_map)

Creates a NodeBatch from a tensor of node IDs and the tensors holding their condensed node type and condensed edge type. Each batch contains nodes of a single condensed node type, which is useful during inference when collecting embeddings for a range of nodes.

Parameters:
  • nodes (torch.Tensor) – A tensor containing the node IDs.

  • condensed_node_type (torch.Tensor) – A tensor representing the condensed node type.

  • condensed_edge_type (torch.Tensor) – A tensor representing the condensed edge type.

  • condensed_node_type_to_node_type_map (dict[gigl.src.common.types.graph_data.CondensedNodeType, gigl.src.common.types.graph_data.NodeType]) – A mapping from condensed node types to node types.

Returns:

The created NodeBatch.

Return type:

NodeBatch
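As an illustration only, the sketch below shows one way a `from_node_tensors`-style constructor could use `condensed_node_type_to_node_type_map`: turn every condensed node type into a key, then mark each node as present under its batch's key and absent under all others. `JaggedSketch`, `NodeBatchSketch`, `from_node_tensors_sketch`, the key ordering, and the use of plain lists and ints instead of tensors are all assumptions, not GiGL's implementation.

```python
from dataclasses import dataclass


@dataclass
class JaggedSketch:
    """Hypothetical stand-in for torchrec.KeyedJaggedTensor."""

    keys: list[str]
    values: list[int]
    lengths: list[int]  # key-major: len(keys) * batch_size entries


@dataclass
class NodeBatchSketch:
    """Hypothetical mirror of NodeBatch's three fields using plain Python types."""

    nodes: JaggedSketch
    condensed_node_type: int
    condensed_edge_type: int


def from_node_tensors_sketch(
    nodes: list[int],
    condensed_node_type: int,
    condensed_edge_type: int,
    condensed_node_type_to_node_type_map: dict[int, str],
) -> NodeBatchSketch:
    # Assumed layout: one key per node type, ordered by condensed node type
    # so every batch shares the same key order.
    keys = [
        condensed_node_type_to_node_type_map[c]
        for c in sorted(condensed_node_type_to_node_type_map)
    ]
    batch_key = condensed_node_type_to_node_type_map[condensed_node_type]
    lengths: list[int] = []
    for key in keys:
        # Key-major lengths: each node has one ID under its own type, none elsewhere.
        lengths.extend([1 if key == batch_key else 0] * len(nodes))
    return NodeBatchSketch(
        nodes=JaggedSketch(keys=keys, values=list(nodes), lengths=lengths),
        condensed_node_type=condensed_node_type,
        condensed_edge_type=condensed_edge_type,
    )


batch = from_node_tensors_sketch([10, 11], 1, 0, {0: "user", 1: "item"})
# batch.nodes.lengths == [0, 0, 1, 1]
```

Fixing the key order from the mapping, rather than from the batch contents, keeps every batch's layout identical regardless of which node type it holds.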

to_node_tensors()

Reconstructs the tensors comprising the NodeBatch.

Parameters:

condensed_node_type_to_node_type_map (dict[CondensedNodeType, NodeType]) – A mapping from condensed node types to node types.

Returns:

A tuple containing:
  • nodes (torch.Tensor) – The node IDs.

  • condensed_node_type (torch.Tensor) – The condensed node type.

  • condensed_edge_type (torch.Tensor) – The condensed edge type.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
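Conversely, here is a hypothetical sketch of how a tuple like the one `to_node_tensors` returns could be recovered from a single-type layout, using the mapping to translate the populated key back into a condensed node type. It assumes the key-major `keys`/`values`/`lengths` layout of `torchrec.KeyedJaggedTensor` and substitutes plain Python lists and ints for tensors; `to_node_tensors_sketch` is not part of GiGL.

```python
def to_node_tensors_sketch(
    keys: list[str],
    values: list[int],
    lengths: list[int],
    condensed_edge_type: int,
    condensed_node_type_to_node_type_map: dict[int, str],
) -> tuple[list[int], int, int]:
    """Hypothetical inverse of the assumed single-type keyed-jagged layout."""
    batch_size = len(lengths) // len(keys)
    # Find the one key whose per-item lengths are non-zero.
    populated = [
        key
        for i, key in enumerate(keys)
        if any(lengths[i * batch_size : (i + 1) * batch_size])
    ]
    if len(populated) != 1:
        raise ValueError(f"expected exactly one populated key, got {populated}")
    # Invert the condensed-type -> node-type map to recover the condensed type.
    node_type_to_condensed = {
        node_type: condensed
        for condensed, node_type in condensed_node_type_to_node_type_map.items()
    }
    condensed_node_type = node_type_to_condensed[populated[0]]
    # With a single populated key, the flat values are exactly the node IDs, in order.
    return list(values), condensed_node_type, condensed_edge_type


nodes, condensed_node_type, condensed_edge_type = to_node_tensors_sketch(
    keys=["user", "item"],
    values=[10, 11],
    lengths=[0, 0, 1, 1],
    condensed_edge_type=3,
    condensed_node_type_to_node_type_map={0: "user", 1: "item"},
)
# nodes == [10, 11], condensed_node_type == 1, condensed_edge_type == 3
```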

condensed_edge_type: torch.Tensor
condensed_node_type: torch.Tensor
nodes: torchrec.KeyedJaggedTensor

gigl.experimental.knowledge_graph_embedding.lib.data.node_batch.collate_node_batch_from_range(nodes, condensed_node_type, condensed_edge_type, condensed_node_type_to_node_type_map)

Collates a batch of nodes into a NodeBatch. This is used for inference when we want to collect embeddings for a range of nodes.

Parameters:
  • nodes (Iterable[int]) – An iterable of node IDs.

  • condensed_node_type (CondensedNodeType) – The condensed node type for the batch.

  • condensed_edge_type (CondensedEdgeType) – The condensed edge type for the batch (relevant to inference).

  • condensed_node_type_to_node_type_map (dict[CondensedNodeType, NodeType]) – A mapping from condensed node types to node types.

Return type:

NodeBatch
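Because collate_node_batch_from_range accepts any Iterable[int] of node IDs, inference over a large ID range can be split into fixed-size chunks and collated one chunk at a time. The sketch below is a hypothetical illustration: `batch_size`, `chunk_node_ids`, `collate_range_sketch`, and the plain dict yielded per chunk are assumptions standing in for a real NodeBatch, so the example runs without GiGL or PyTorch.

```python
from itertools import islice
from typing import Iterable, Iterator


def chunk_node_ids(nodes: Iterable[int], batch_size: int) -> Iterator[list[int]]:
    """Yield consecutive chunks of at most `batch_size` node IDs (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    it = iter(nodes)
    # islice pulls up to batch_size IDs; an empty chunk means the input is exhausted.
    while chunk := list(islice(it, batch_size)):
        yield chunk


def collate_range_sketch(
    nodes: Iterable[int],
    condensed_node_type: int,
    condensed_edge_type: int,
    batch_size: int,
) -> Iterator[dict]:
    """Hypothetical stand-in: one dict per chunk instead of a real NodeBatch."""
    for chunk in chunk_node_ids(nodes, batch_size):
        # Every chunk shares the same condensed node type and condensed edge type.
        yield {
            "nodes": chunk,
            "condensed_node_type": condensed_node_type,
            "condensed_edge_type": condensed_edge_type,
        }


# Split node IDs 0..4 into batches of 2 for inference.
batches = list(
    collate_range_sketch(range(5), condensed_node_type=0, condensed_edge_type=1, batch_size=2)
)
```

With GiGL installed, each chunk could instead be passed to collate_node_batch_from_range together with the condensed node type, the condensed edge type, and condensed_node_type_to_node_type_map. Chunking lazily with islice avoids materializing the full ID range in memory.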