Source code for gigl.distributed.base_sampler

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, Union

import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import DistNeighborSampler as GLTDistNeighborSampler
from graphlearn_torch.distributed.dist_feature import DistFeature
from graphlearn_torch.distributed.event_loop import wrap_torch_future
from graphlearn_torch.sampler import (
    HeteroSamplerOutput,
    NodeSamplerInput,
    SamplerOutput,
)
from graphlearn_torch.typing import NodeType, as_str
from graphlearn_torch.utils import reverse_edge_type

from gigl.distributed.sampler import (
    NEGATIVE_LABEL_METADATA_KEY,
    POSITIVE_LABEL_METADATA_KEY,
    ABLPNodeSamplerInput,
)
from gigl.utils.data_splitters import PADDING_NODE


def _stable_unique_preserve_order(nodes: torch.Tensor) -> torch.Tensor:
    """Return unique 1-D values while preserving first-occurrence order.

    Args:
        nodes: A 1-D tensor of node IDs (may contain duplicates).

    Returns:
        A 1-D tensor of unique node IDs in first-occurrence order.

    Raises:
        ValueError: If ``nodes`` is not 1-D.
    """
    if nodes.dim() != 1:
        raise ValueError(
            f"Expected a 1-D tensor of node ids, got shape {tuple(nodes.shape)}."
        )
    if nodes.numel() <= 1:
        return nodes

    unique_nodes, inverse = torch.unique(nodes, sorted=False, return_inverse=True)
    first_positions = torch.full(
        (unique_nodes.numel(),),
        fill_value=nodes.numel(),
        dtype=torch.long,
        device=nodes.device,
    )
    positions = torch.arange(nodes.numel(), device=nodes.device)
    first_positions.scatter_reduce_(
        0,
        inverse,
        positions,
        reduce="amin",
        include_self=True,
    )
    stable_order = torch.argsort(first_positions)
    return unique_nodes[stable_order]
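
# Illustrative behavior (a sketch, not part of the module's public surface):
# duplicates are dropped while first-occurrence order is kept, unlike plain
# ``torch.unique``, which returns values in sorted order by default.
#
#   >>> _stable_unique_preserve_order(torch.tensor([5, 3, 5, 7, 3]))
#   tensor([5, 3, 7])
#   >>> torch.unique(torch.tensor([5, 3, 5, 7, 3]))
#   tensor([3, 5, 7])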


@dataclass
class SampleLoopInputs:
    """Inputs prepared for the neighbor sampling loop in ``_sample_from_nodes``.

    This dataclass holds the processed inputs that are passed to the core
    sampling loop. It allows ``_prepare_sample_loop_inputs`` to customize what
    nodes are sampled from and what metadata is attached to the output,
    without duplicating the sampling loop logic.

    Attributes:
        nodes_to_sample: For homogeneous graphs, a tensor of node IDs. For
            heterogeneous graphs, a dict mapping node types to tensors. For
            ABLP, this also includes supervision nodes (positive/negative
            labels).
        metadata: Metadata dict to attach to the sampler output (e.g., label
            tensors).
    """

    nodes_to_sample: Union[torch.Tensor, dict[NodeType, torch.Tensor]]
    metadata: dict[str, torch.Tensor]
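
# Illustrative only (node types and IDs below are hypothetical): a homogeneous
# batch carries a plain tensor of seed IDs, while a heterogeneous or ABLP batch
# carries a per-node-type dict.
#
#   SampleLoopInputs(nodes_to_sample=torch.tensor([0, 4, 9]), metadata={})
#   SampleLoopInputs(
#       nodes_to_sample={"user": torch.tensor([0, 4]), "item": torch.tensor([7])},
#       metadata={},
#   )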
class BaseDistNeighborSampler(GLTDistNeighborSampler):
    """Base class for GiGL distributed samplers.

    Extends GLT's DistNeighborSampler with shared utilities for preparing
    sampling inputs, including ABLP (anchor-based link prediction) support.

    Subclasses must override ``_sample_from_nodes`` with their specific
    sampling strategy (e.g., k-hop neighbor sampling, PPR-based sampling).
    """

    def _prepare_sample_loop_inputs(
        self,
        inputs: NodeSamplerInput,
    ) -> SampleLoopInputs:
        """Prepare inputs for the sampling loop.

        Handles both standard NodeSamplerInput and ABLPNodeSamplerInput. For
        ABLP inputs, adds supervision nodes to the sampling seeds and builds
        label metadata.

        Args:
            inputs: Either a NodeSamplerInput or an ABLPNodeSamplerInput.

        Returns:
            SampleLoopInputs containing the nodes to sample from and any
            metadata related to the task (e.g., label tensors for ABLP).
        """
        input_seeds = inputs.node.to(self.device)
        input_type = inputs.input_type

        if isinstance(inputs, ABLPNodeSamplerInput):
            return self._prepare_ablp_inputs(inputs, input_seeds, input_type)

        # For homogeneous graphs (input_type is None), return the tensor
        # directly. For heterogeneous graphs, return a dict mapping node type
        # to tensor.
        if input_type is None:
            return SampleLoopInputs(
                nodes_to_sample=input_seeds,
                metadata={},
            )
        return SampleLoopInputs(
            nodes_to_sample={input_type: input_seeds},
            metadata={},
        )
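
    # Illustrative sketch of the metadata keys built by _prepare_ablp_inputs
    # below (the edge type is hypothetical): each label tensor is stored under
    # a string key, since GLT metadata only supports dict[str, torch.Tensor].
    # For a supervision edge type ("user", "to", "item"), the positive-label
    # key would be:
    #
    #   f"{POSITIVE_LABEL_METADATA_KEY}{str(tuple(('user', 'to', 'item')))}"
    #   # == POSITIVE_LABEL_METADATA_KEY + "('user', 'to', 'item')"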
metadata[f"{NEGATIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}"] = ( label_tensor ) nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] = { # Keep first-occurrence order so anchor seeds remain at the front of # their node type; graph-transformer paths rely on that convention. node_type: _stable_unique_preserve_order( torch.cat(seeds, dim=0).to(self.device) ) for node_type, seeds in input_seeds_builder.items() } return SampleLoopInputs( nodes_to_sample=nodes_to_sample, metadata=metadata, ) async def _send_adapter( self, async_func, *args, **kwargs, ) -> Optional[SampleMessage]: """Override GLT's ``_send_adapter`` to call ``_collate_fn`` (corrected spelling). GLT's original calls ``self._colloate_fn`` (typo). This override is the only place in GiGL that references the typo — everything else uses ``_collate_fn``. Copied from ``graphlearn_torch.distributed.DistNeighborSampler._send_adapter`` (GLT 0.2.4) with the single change of ``_colloate_fn`` → ``_collate_fn``. """ sampler_output = await async_func(*args, **kwargs) res = await self._collate_fn(sampler_output) if self.channel is None: return res self.channel.send(res) return None async def _collate_fn( self, output: Union[SamplerOutput, HeteroSamplerOutput], ) -> SampleMessage: """Collect labels and features for the sampled subgraph into a SampleMessage. Copied from ``graphlearn_torch.distributed.DistNeighborSampler._colloate_fn`` (GLT 0.2.4). The method name preserves GLT's original typo so that this override is matched correctly at runtime. The only behavioural change from the GLT original is in the ``DistFeature`` label-fetch paths (both homogeneous and heterogeneous): GLT writes ``nlabels.T[0]``, which silently discards all label columns beyond the first and breaks multi-label node classification. This override writes the full ``nlabels`` tensor instead, avoiding the extra RPC call that a super()-then- re-fetch approach would require. The non-``DistFeature`` path (plain ``torch.Tensor`` labels) is unchanged — it never applied ``.T[0]``. # TODO (mkolodner-sc): Now that GiGL owns this method, investigate whether # post-processing steps in DistNeighborLoader._collate_fn can be folded in # here and simplified — e.g. set_missing_features (populating empty tensors # for node/edge features not fanned out to) and extract_metadata (stripping # #META. keys before to_hetero_data to work around a GLT bug where those # keys are misinterpreted as edge types). Args: output: The ``SamplerOutput`` or ``HeteroSamplerOutput`` returned by ``_sample_from_nodes``. Returns: A ``SampleMessage`` (``dict[str, torch.Tensor]``) ready to be sent over the sampling channel or returned directly to the loader. 
""" result_map: SampleMessage = {} is_hetero = self.dist_graph.data_cls == "hetero" result_map["#IS_HETERO"] = torch.LongTensor([int(is_hetero)]) if isinstance(output.metadata, dict): for k, v in output.metadata.items(): result_map[f"#META.{k}"] = v if is_hetero: for ntype, nodes in output.node.items(): result_map[f"{as_str(ntype)}.ids"] = nodes if output.num_sampled_nodes is not None: if ntype in output.num_sampled_nodes: result_map[f"{as_str(ntype)}.num_sampled_nodes"] = torch.tensor( output.num_sampled_nodes[ntype], device=self.device ) for etype, rows in output.row.items(): etype_str = as_str(etype) result_map[f"{etype_str}.rows"] = rows result_map[f"{etype_str}.cols"] = output.col[etype] if self.with_edge: result_map[f"{etype_str}.eids"] = output.edge[etype] if output.num_sampled_edges is not None: if etype in output.num_sampled_edges: result_map[f"{etype_str}.num_sampled_edges"] = torch.tensor( output.num_sampled_edges[etype], device=self.device ) input_type = output.input_type assert input_type is not None if not isinstance(input_type, tuple): if self.dist_node_labels is not None: if isinstance(self.dist_node_labels, DistFeature): fut = self.dist_node_labels.async_get( output.node[input_type], input_type ) nlabels = await wrap_torch_future(fut) # DistFeature always returns [N, K]. We collapse K=1 to 1-D # [N] to match GLT's convention and what downstream code # (e.g. CrossEntropyLoss) expects for data.y. Multi-label # (K>1) keeps the full 2-D matrix. # TODO (mkolodner-sc): Consider investigating always returning # 2-D — this may be a breaking change for single-label # training pipelines (e.g. CrossEntropyLoss expects 1-D data.y). result_map[f"{as_str(input_type)}.nlabels"] = ( nlabels if nlabels.shape[1] > 1 else nlabels.T[0] ) else: node_labels = self.dist_node_labels.get(input_type, None) if node_labels is not None: result_map[f"{as_str(input_type)}.nlabels"] = node_labels[ output.node[input_type].to(node_labels.device) ] if self.dist_node_feature is not None: if self.use_all2all: sorted_ntype = sorted(self.dist_node_feature.feature_pb.keys()) nfeat_dict = self.dist_node_feature.get_all2all( output, sorted_ntype ) for ntype, nfeats in nfeat_dict.items(): result_map[f"{as_str(ntype)}.nfeats"] = nfeats else: nfeat_fut_dict = {} for ntype, nodes in output.node.items(): nodes = nodes.to(torch.long) nfeat_fut_dict[ntype] = self.dist_node_feature.async_get( nodes, ntype ) for ntype, fut in nfeat_fut_dict.items(): nfeats = await wrap_torch_future(fut) result_map[f"{as_str(ntype)}.nfeats"] = nfeats if self.dist_edge_feature is not None and self.with_edge: efeat_fut_dict = {} for etype in self.edge_types: if self.edge_dir == "in": eids = result_map.get( f"{as_str(reverse_edge_type(etype))}.eids", None ) elif self.edge_dir == "out": eids = result_map.get(f"{as_str(etype)}.eids", None) if eids is not None: eids = eids.to(torch.long) efeat_fut_dict[etype] = self.dist_edge_feature.async_get( eids, etype ) for etype, fut in efeat_fut_dict.items(): efeats = await wrap_torch_future(fut) if self.edge_dir == "out": result_map[f"{as_str(etype)}.efeats"] = efeats elif self.edge_dir == "in": result_map[f"{as_str(reverse_edge_type(etype))}.efeats"] = ( efeats ) if output.batch is not None: for ntype, batch in output.batch.items(): result_map[f"{as_str(ntype)}.batch"] = batch else: result_map["ids"] = output.node result_map["rows"] = output.row result_map["cols"] = output.col if output.num_sampled_nodes is not None: result_map["num_sampled_nodes"] = torch.tensor( output.num_sampled_nodes, 
device=self.device ) result_map["num_sampled_edges"] = torch.tensor( output.num_sampled_edges, device=self.device ) if self.with_edge: result_map["eids"] = output.edge if self.dist_node_labels is not None: if isinstance(self.dist_node_labels, DistFeature): fut = self.dist_node_labels.async_get(output.node) nlabels = await wrap_torch_future(fut) # DistFeature always returns [N, K]. We collapse K=1 to 1-D # [N] to match GLT's convention and what downstream code # (e.g. CrossEntropyLoss) expects for data.y. Multi-label # (K>1) keeps the full 2-D matrix. # TODO (mkolodner-sc): Consider investigating always returning # 2-D — this may be a breaking change for single-label # training pipelines (e.g. CrossEntropyLoss expects 1-D data.y). result_map["nlabels"] = ( nlabels if nlabels.shape[1] > 1 else nlabels.T[0] ) else: result_map["nlabels"] = self.dist_node_labels[ output.node.to(self.dist_node_labels.device) ] if self.dist_node_feature is not None: fut = self.dist_node_feature.async_get(output.node) nfeats = await wrap_torch_future(fut) result_map["nfeats"] = nfeats if self.dist_edge_feature is not None: eids = result_map["eids"] fut = self.dist_edge_feature.async_get(eids) efeats = await wrap_torch_future(fut) result_map["efeats"] = efeats if output.batch is not None: result_map["batch"] = output.batch return result_map async def _sample_from_nodes( self, inputs: NodeSamplerInput, ) -> Union[SamplerOutput, HeteroSamplerOutput]: """Sample subgraph from seed nodes. Subclasses must override this method with their specific sampling strategy. Args: inputs: The seed nodes to sample from. Raises: NotImplementedError: Always — subclasses must override. """ raise NotImplementedError( f"{type(self).__name__} must override _sample_from_nodes." )
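
# A minimal subclass sketch (hypothetical; MyNeighborSampler and its sampling
# strategy are illustrative, not part of GiGL): concrete samplers override
# _sample_from_nodes and can reuse _prepare_sample_loop_inputs for seeds and
# metadata.
#
#   class MyNeighborSampler(BaseDistNeighborSampler):
#       async def _sample_from_nodes(
#           self,
#           inputs: NodeSamplerInput,
#       ) -> Union[SamplerOutput, HeteroSamplerOutput]:
#           loop_inputs = self._prepare_sample_loop_inputs(inputs)
#           # ...run the subclass-specific sampling over
#           # loop_inputs.nodes_to_sample, then attach loop_inputs.metadata
#           # to the returned SamplerOutput / HeteroSamplerOutput...
#           raise NotImplementedError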