Source code for gigl.experimental.knowledge_graph_embedding.lib.data.node_batch

from __future__ import annotations

from dataclasses import dataclass
from functools import partial
from typing import Iterable

import torch
import torchrec

from gigl.experimental.knowledge_graph_embedding.common.torchrec.batch import (
    DataclassBatch,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.dataloader import (
    DataloaderConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
    SamplingConfig,
)
from gigl.src.common.types.graph_data import (
    CondensedEdgeType,
    CondensedNodeType,
    NodeType,
)
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper


@dataclass
class NodeBatch(DataclassBatch):
    """A batch of nodes in a heterogeneous graph.

    All nodes in a batch share the same condensed node type, and inference is
    being run in the context of a single condensed edge type.
    """

    # Node IDs keyed by condensed node type; only this batch's own condensed
    # node type carries values (see `from_node_tensors`).
    nodes: torchrec.KeyedJaggedTensor
    # Scalar tensor holding this batch's condensed node type.
    condensed_node_type: torch.Tensor
    # Scalar tensor holding the condensed edge type inference runs under.
    condensed_edge_type: torch.Tensor

    @staticmethod
    def from_node_tensors(
        nodes: torch.Tensor,
        condensed_node_type: torch.Tensor,
        condensed_edge_type: torch.Tensor,
        condensed_node_type_to_node_type_map: dict[CondensedNodeType, NodeType],
    ) -> NodeBatch:
        """Creates a NodeBatch from a tensor of node IDs.

        Each batch contains nodes of a single condensed node type. This is
        useful for inference when we want to collect embeddings for a range
        of nodes.

        Args:
            nodes: A 1-D tensor containing the node IDs.
            condensed_node_type: A scalar tensor with the batch's condensed
                node type.
            condensed_edge_type: A scalar tensor with the condensed edge type.
            condensed_node_type_to_node_type_map: A mapping from condensed
                node types to node types; its (sorted) keys define the key
                set of the resulting KeyedJaggedTensor.

        Returns:
            NodeBatch: The created NodeBatch.
        """
        num_nodes = nodes.numel()
        # Hoisted out of the loop: `.item()` forces a host sync and was
        # previously evaluated twice per key.
        batch_cnt = condensed_node_type.item()
        cnt_keys = sorted(condensed_node_type_to_node_type_map.keys())
        lengths: dict[CondensedNodeType, torch.Tensor] = {}
        values: dict[CondensedNodeType, torch.Tensor] = {}
        for cnt_key in cnt_keys:
            if cnt_key == batch_cnt:
                # Every node contributes exactly one value under its own key.
                lengths[cnt_key] = torch.ones(num_nodes, dtype=torch.int32)
                values[cnt_key] = nodes
            else:
                # Other keys are present but empty (lengths of zero).
                lengths[cnt_key] = torch.zeros(num_nodes, dtype=torch.int32)
                # Match `nodes.dtype` so the torch.cat below never mixes
                # dtypes (previously hard-coded to int32).
                values[cnt_key] = torch.empty(0, dtype=nodes.dtype)
        kjt = torchrec.KeyedJaggedTensor(
            keys=cnt_keys,
            values=torch.cat([values[cnt] for cnt in cnt_keys], dim=0),
            lengths=torch.cat([lengths[cnt] for cnt in cnt_keys], dim=0),
        )
        return NodeBatch(
            nodes=kjt,
            condensed_node_type=condensed_node_type,
            condensed_edge_type=condensed_edge_type,
        )

    def to_node_tensors(
        self,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reconstructs the tensors comprising the NodeBatch.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
                - nodes: The node IDs.
                - condensed_node_type: The condensed node type.
                - condensed_edge_type: The condensed edge type.
        """
        lengths_per_key = torch.tensor(self.nodes.length_per_key())
        # Sanity checks: exactly one key carries values, and that key is the
        # one matching this batch's condensed node type.
        assert lengths_per_key.argwhere().ravel().numel() == 1
        assert lengths_per_key.argmax().item() == self.condensed_node_type.item()
        return self.nodes.values(), self.condensed_node_type, self.condensed_edge_type

    @staticmethod
    def build_data_loader(
        dataset: torch.utils.data.IterableDataset,
        condensed_node_type: CondensedNodeType,
        condensed_edge_type: CondensedEdgeType,
        graph_metadata: GraphMetadataPbWrapper,
        sampling_config: SamplingConfig,
        dataloader_config: DataloaderConfig,
        pin_memory: bool,
    ):
        """Builds a DataLoader whose collate function yields NodeBatches.

        Args:
            dataset: The iterable dataset of node IDs to batch.
            condensed_node_type: Condensed node type for every batch produced.
            condensed_edge_type: Condensed edge type for every batch produced.
            graph_metadata: Wrapper supplying the condensed-node-type map.
            sampling_config: Supplies the batch size.
            dataloader_config: Supplies the worker count.
            pin_memory: Whether the DataLoader pins host memory.

        Returns:
            A torch DataLoader producing NodeBatch instances.
        """
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=sampling_config.positive_edge_batch_size,  # todo(nshah): use inference batch size explicitly.
            num_workers=dataloader_config.num_workers,
            pin_memory=pin_memory,
            collate_fn=partial(
                collate_node_batch_from_range,
                condensed_node_type=condensed_node_type,
                condensed_edge_type=condensed_edge_type,
                condensed_node_type_to_node_type_map=graph_metadata.condensed_node_type_to_node_type_map,
            ),
        )
[docs] def collate_node_batch_from_range( nodes: Iterable[int], condensed_node_type: CondensedNodeType, condensed_edge_type: CondensedEdgeType, condensed_node_type_to_node_type_map: dict[CondensedNodeType, NodeType], ) -> NodeBatch: """ Collates a batch of nodes into a NodeBatch. This is used for inference when we want to collect embeddings for a range of nodes. Args: nodes (Iterable[int]): An iterable of node IDs. condensed_node_type (CondensedNodeType): The condensed node type for the batch. condensed_edge_type (CondensedEdgeType): The condensed edge type for the batch (relevant to inference). condensed_node_type_to_node_type_map (dict[CondensedNodeType, NodeType]): A mapping from condensed node types to node types. """ return NodeBatch.from_node_tensors( nodes=torch.tensor(nodes, dtype=torch.int32), condensed_node_type=torch.tensor(condensed_node_type, dtype=torch.int32), condensed_edge_type=torch.tensor(condensed_edge_type, dtype=torch.int32), condensed_node_type_to_node_type_map=condensed_node_type_to_node_type_map, )