from __future__ import annotations
from dataclasses import dataclass
from functools import partial
from typing import Iterable
import torch
import torchrec
from gigl.experimental.knowledge_graph_embedding.common.torchrec.batch import (
DataclassBatch,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.dataloader import (
DataloaderConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
SamplingConfig,
)
from gigl.src.common.types.graph_data import (
CondensedEdgeType,
CondensedNodeType,
NodeType,
)
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
@dataclass
[docs]
class NodeBatch(DataclassBatch):
"""
A class for representing a batch of nodes in a heterogeneous graph.
These nodes share the same condensed_node_type, and inference is
being run in the context of a single condensed edge type.
"""
[docs]
nodes: torchrec.KeyedJaggedTensor
# node_ids: torch.Tensor # Denormalized for convenience. Same as `nodes.values()`.
[docs]
condensed_node_type: torch.Tensor
[docs]
condensed_edge_type: torch.Tensor
@staticmethod
[docs]
def from_node_tensors(
nodes: torch.Tensor,
condensed_node_type: torch.Tensor,
condensed_edge_type: torch.Tensor,
condensed_node_type_to_node_type_map: dict[CondensedNodeType, NodeType],
) -> NodeBatch:
"""
Creates a NodeBatch from a range of nodes. Each batch will contain
nodes of a single condensed node type. This is useful for inference
when we want to collect embeddings for a range of nodes.
Args:
nodes: torch.Tensor: A tensor containing the node IDs.
condensed_node_type: torch.Tensor: A tensor representing the condensed node type.
condensed_edge_type: torch.Tensor: A tensor representing the condensed edge type.
condensed_node_type_to_node_type_map: dict[CondensedNodeType, NodeType]: A mapping from condensed node types to node types.
Returns:
NodeBatch: The created NodeBatch.
"""
num_nodes = nodes.size() # Inclusive of start and end
cnt_keys = sorted(list(condensed_node_type_to_node_type_map.keys()))
lengths: dict[CondensedNodeType, torch.Tensor] = dict()
values: dict[CondensedNodeType, torch.Tensor] = dict()
for cnt_key in cnt_keys:
lengths[cnt_key] = (
torch.ones(num_nodes, dtype=torch.int32)
if cnt_key == condensed_node_type.item()
else torch.zeros(num_nodes, dtype=torch.int32)
)
values[cnt_key] = (
nodes
if cnt_key == condensed_node_type.item()
else torch.empty(0, dtype=torch.int32)
)
nodes = torchrec.KeyedJaggedTensor(
keys=cnt_keys,
values=torch.cat([values[cnt] for cnt in cnt_keys], dim=0),
lengths=torch.cat([lengths[cnt] for cnt in cnt_keys], dim=0),
)
return NodeBatch(
nodes=nodes,
condensed_node_type=condensed_node_type,
condensed_edge_type=condensed_edge_type,
)
[docs]
def to_node_tensors(
self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reconstructs the tensors comprising the NodeBatch.
Args:
condensed_node_type_to_node_type_map (dict[CondensedNodeType, NodeType]): A mapping from condensed node types to node types.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
- nodes: torch.Tensor: The node IDs.
- condensed_node_type: torch.Tensor: The condensed node type.
- condensed_edge_type: torch.Tensor: The condensed edge type.
"""
lengths_per_key = torch.tensor(self.nodes.length_per_key())
assert lengths_per_key.argwhere().ravel().numel() == 1
assert lengths_per_key.argmax().item() == self.condensed_node_type.item()
return self.nodes.values(), self.condensed_node_type, self.condensed_edge_type
@staticmethod
[docs]
def build_data_loader(
dataset: torch.utils.data.IterableDataset,
condensed_node_type: CondensedNodeType,
condensed_edge_type: CondensedEdgeType,
graph_metadata: GraphMetadataPbWrapper,
sampling_config: SamplingConfig,
dataloader_config: DataloaderConfig,
pin_memory: bool,
):
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=sampling_config.positive_edge_batch_size, # todo(nshah): use inference batch size explicitly.
num_workers=dataloader_config.num_workers,
pin_memory=pin_memory,
collate_fn=partial(
collate_node_batch_from_range,
condensed_node_type=condensed_node_type,
condensed_edge_type=condensed_edge_type,
condensed_node_type_to_node_type_map=graph_metadata.condensed_node_type_to_node_type_map,
),
)
[docs]
def collate_node_batch_from_range(
nodes: Iterable[int],
condensed_node_type: CondensedNodeType,
condensed_edge_type: CondensedEdgeType,
condensed_node_type_to_node_type_map: dict[CondensedNodeType, NodeType],
) -> NodeBatch:
"""
Collates a batch of nodes into a NodeBatch.
This is used for inference when we want to collect embeddings for a range of nodes.
Args:
nodes (Iterable[int]): An iterable of node IDs.
condensed_node_type (CondensedNodeType): The condensed node type for the batch.
condensed_edge_type (CondensedEdgeType): The condensed edge type for the batch (relevant to inference).
condensed_node_type_to_node_type_map (dict[CondensedNodeType, NodeType]): A mapping from condensed node types to node types.
"""
return NodeBatch.from_node_tensors(
nodes=torch.tensor(nodes, dtype=torch.int32),
condensed_node_type=torch.tensor(condensed_node_type, dtype=torch.int32),
condensed_edge_type=torch.tensor(condensed_edge_type, dtype=torch.int32),
condensed_node_type_to_node_type_map=condensed_node_type_to_node_type_map,
)