from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, Union
import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import DistNeighborSampler as GLTDistNeighborSampler
from graphlearn_torch.distributed.dist_feature import DistFeature
from graphlearn_torch.distributed.event_loop import wrap_torch_future
from graphlearn_torch.sampler import (
HeteroSamplerOutput,
NodeSamplerInput,
SamplerOutput,
)
from graphlearn_torch.typing import NodeType, as_str
from graphlearn_torch.utils import reverse_edge_type
from gigl.distributed.sampler import (
NEGATIVE_LABEL_METADATA_KEY,
POSITIVE_LABEL_METADATA_KEY,
ABLPNodeSamplerInput,
)
from gigl.utils.data_splitters import PADDING_NODE


def _stable_unique_preserve_order(nodes: torch.Tensor) -> torch.Tensor:
"""Return unique 1-D values while preserving first-occurrence order.
Args:
nodes: A 1-D tensor of node IDs (may contain duplicates).
Returns:
A 1-D tensor of unique node IDs in first-occurrence order.
Raises:
ValueError: If ``nodes`` is not 1-D.
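
    Example (illustrative):
        >>> _stable_unique_preserve_order(torch.tensor([5, 3, 5, 9, 3]))
        tensor([5, 3, 9])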
"""
if nodes.dim() != 1:
raise ValueError(
f"Expected a 1-D tensor of node ids, got shape {tuple(nodes.shape)}."
)
if nodes.numel() <= 1:
return nodes
unique_nodes, inverse = torch.unique(nodes, sorted=False, return_inverse=True)
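    # torch.unique gives no ordering guarantee here (even with ``sorted=False``),
    # so recover each value's first-occurrence position explicitly below.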
first_positions = torch.full(
(unique_nodes.numel(),),
fill_value=nodes.numel(),
dtype=torch.long,
device=nodes.device,
)
positions = torch.arange(nodes.numel(), device=nodes.device)
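    # amin-reduce positions into one slot per unique value, so each slot ends up
    # holding the earliest index at which that value appears in ``nodes``.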
first_positions.scatter_reduce_(
0,
inverse,
positions,
reduce="amin",
include_self=True,
)
stable_order = torch.argsort(first_positions)
return unique_nodes[stable_order]


@dataclass
class SampleLoopInputs:
    """Inputs to one pass of the sampling loop.

    Attributes:
        nodes_to_sample: Seed node IDs to sample from: a single tensor for
            homogeneous graphs, or a mapping from node type to tensor for
            heterogeneous graphs.
        metadata: Task-related metadata (e.g., ABLP label tensors), keyed by
            string to match GLT's ``dict[str, torch.Tensor]`` convention.
    """

    nodes_to_sample: Union[torch.Tensor, dict[Union[str, NodeType], torch.Tensor]]
    metadata: dict[str, torch.Tensor]


class BaseDistNeighborSampler(GLTDistNeighborSampler):
"""Base class for GiGL distributed samplers.
Extends GLT's DistNeighborSampler with shared utilities for preparing
sampling inputs, including ABLP (anchor-based link prediction) support.
Subclasses must override ``_sample_from_nodes`` with their specific
sampling strategy (e.g., k-hop neighbor sampling, PPR-based sampling).
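
    Example (illustrative sketch; ``MyNeighborSampler`` is hypothetical)::

        class MyNeighborSampler(BaseDistNeighborSampler):
            async def _sample_from_nodes(self, inputs):
                loop_inputs = self._prepare_sample_loop_inputs(inputs)
                raise NotImplementedError  # k-hop, PPR, etc. goes here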
"""

    def _prepare_sample_loop_inputs(
self,
inputs: NodeSamplerInput,
) -> SampleLoopInputs:
"""Prepare inputs for the sampling loop.
Handles both standard NodeSamplerInput and ABLPNodeSamplerInput.
For ABLP inputs, adds supervision nodes to the sampling seeds and
builds label metadata.
Args:
inputs: Either a NodeSamplerInput or ABLPNodeSamplerInput.
Returns:
SampleLoopInputs containing the nodes to sample from and any
metadata related to the task (e.g., label tensors for ABLP).
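
        Example (illustrative): for a homogeneous graph (``input_type`` is
        None), ``nodes_to_sample`` is the seed tensor itself; for a
        heterogeneous graph it is ``{input_type: seeds}``.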
"""
input_seeds = inputs.node.to(self.device)
input_type = inputs.input_type
if isinstance(inputs, ABLPNodeSamplerInput):
return self._prepare_ablp_inputs(inputs, input_seeds, input_type)
# For homogeneous graphs (input_type is None), return tensor directly.
# For heterogeneous graphs, return dict mapping node type to tensor.
if input_type is None:
return SampleLoopInputs(
nodes_to_sample=input_seeds,
metadata={},
)
return SampleLoopInputs(
nodes_to_sample={input_type: input_seeds},
metadata={},
)

    def _prepare_ablp_inputs(
self,
inputs: ABLPNodeSamplerInput,
input_seeds: torch.Tensor,
input_type: NodeType,
) -> SampleLoopInputs:
"""Prepare ABLP inputs with supervision nodes and label metadata.
Args:
inputs: The ABLPNodeSamplerInput containing label information.
input_seeds: The anchor node seeds (already moved to device).
input_type: The node type of the anchor seeds.
Returns:
SampleLoopInputs with supervision nodes included in nodes_to_sample
and label tensors in metadata.
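
        Example (illustrative; "user"/"item" are hypothetical node types):
        with anchors of type "user" and positive labels on
        ("user", "to", "item"), ``nodes_to_sample`` maps "user" to the anchor
        seeds and "item" to the deduplicated supervision nodes, while
        ``metadata`` stores the raw (still padded) label tensor under
        ``f"{POSITIVE_LABEL_METADATA_KEY}{tuple(edge_type)}"``.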
"""
# Since GLT swaps src/dst for edge_dir = "out",
# and GiGL assumes that supervision edge types are always
# (anchor_node_type, to, supervision_node_type),
# we need to index into supervision edge types accordingly.
label_edge_index = 0 if self.edge_dir == "in" else 2
# Build metadata and input nodes from positive/negative labels.
# We need to sample from the supervision nodes as well, and ensure
# that we are sampling from the correct node type.
metadata: dict[str, torch.Tensor] = {}
input_seeds_builder: dict[Union[str, NodeType], list[torch.Tensor]] = (
defaultdict(list)
)
input_seeds_builder[input_type].append(input_seeds)
for edge_type, label_tensor in inputs.positive_label_by_edge_types.items():
filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to(
self.device
)
input_seeds_builder[edge_type[label_edge_index]].append(
filtered_label_tensor
)
# Update the metadata per positive label edge type.
# We do this because GLT only supports dict[str, torch.Tensor] for metadata.
metadata[f"{POSITIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}"] = (
label_tensor
)
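            # Note that the *unfiltered* tensor (PADDING_NODE entries included)
            # goes into metadata, presumably so per-anchor alignment can be
            # reconstructed downstream; the same holds for negative labels below.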
for edge_type, label_tensor in inputs.negative_label_by_edge_types.items():
filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to(
self.device
)
input_seeds_builder[edge_type[label_edge_index]].append(
filtered_label_tensor
)
# Update the metadata per negative label edge type.
# We do this because GLT only supports dict[str, torch.Tensor] for metadata.
metadata[f"{NEGATIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}"] = (
label_tensor
)
nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] = {
# Keep first-occurrence order so anchor seeds remain at the front of
# their node type; graph-transformer paths rely on that convention.
node_type: _stable_unique_preserve_order(
torch.cat(seeds, dim=0).to(self.device)
)
for node_type, seeds in input_seeds_builder.items()
}
return SampleLoopInputs(
nodes_to_sample=nodes_to_sample,
metadata=metadata,
)

    async def _send_adapter(
self,
async_func,
*args,
**kwargs,
) -> Optional[SampleMessage]:
"""Override GLT's ``_send_adapter`` to call ``_collate_fn`` (corrected spelling).
GLT's original calls ``self._colloate_fn`` (typo). This override is the
only place in GiGL that references the typo — everything else uses
``_collate_fn``.
Copied from ``graphlearn_torch.distributed.DistNeighborSampler._send_adapter``
(GLT 0.2.4) with the single change of ``_colloate_fn`` → ``_collate_fn``.
"""
sampler_output = await async_func(*args, **kwargs)
res = await self._collate_fn(sampler_output)
if self.channel is None:
return res
self.channel.send(res)
return None

    async def _collate_fn(
self,
output: Union[SamplerOutput, HeteroSamplerOutput],
) -> SampleMessage:
"""Collect labels and features for the sampled subgraph into a SampleMessage.
Copied from ``graphlearn_torch.distributed.DistNeighborSampler._colloate_fn``
(GLT 0.2.4). The method name preserves GLT's original typo so that this
override is matched correctly at runtime.
The only behavioural change from the GLT original is in the ``DistFeature``
label-fetch paths (both homogeneous and heterogeneous): GLT writes
``nlabels.T[0]``, which silently discards all label columns beyond the first
and breaks multi-label node classification. This override writes the full
``nlabels`` tensor instead, avoiding the extra RPC call that a super()-then-
re-fetch approach would require. The non-``DistFeature`` path (plain
``torch.Tensor`` labels) is unchanged — it never applied ``.T[0]``.
# TODO (mkolodner-sc): Now that GiGL owns this method, investigate whether
# post-processing steps in DistNeighborLoader._collate_fn can be folded in
# here and simplified — e.g. set_missing_features (populating empty tensors
# for node/edge features not fanned out to) and extract_metadata (stripping
# #META. keys before to_hetero_data to work around a GLT bug where those
# keys are misinterpreted as edge types).
Args:
output: The ``SamplerOutput`` or ``HeteroSamplerOutput`` returned by
``_sample_from_nodes``.
Returns:
A ``SampleMessage`` (``dict[str, torch.Tensor]``) ready to be sent
over the sampling channel or returned directly to the loader.
"""
result_map: SampleMessage = {}
is_hetero = self.dist_graph.data_cls == "hetero"
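        # "#IS_HETERO" lets the receiving side decide between homogeneous and
        # heterogeneous reconstruction of the sample.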
result_map["#IS_HETERO"] = torch.LongTensor([int(is_hetero)])
if isinstance(output.metadata, dict):
for k, v in output.metadata.items():
result_map[f"#META.{k}"] = v
if is_hetero:
for ntype, nodes in output.node.items():
result_map[f"{as_str(ntype)}.ids"] = nodes
if output.num_sampled_nodes is not None:
if ntype in output.num_sampled_nodes:
result_map[f"{as_str(ntype)}.num_sampled_nodes"] = torch.tensor(
output.num_sampled_nodes[ntype], device=self.device
)
for etype, rows in output.row.items():
etype_str = as_str(etype)
result_map[f"{etype_str}.rows"] = rows
result_map[f"{etype_str}.cols"] = output.col[etype]
if self.with_edge:
result_map[f"{etype_str}.eids"] = output.edge[etype]
if output.num_sampled_edges is not None:
if etype in output.num_sampled_edges:
result_map[f"{etype_str}.num_sampled_edges"] = torch.tensor(
output.num_sampled_edges[etype], device=self.device
)
input_type = output.input_type
assert input_type is not None
if not isinstance(input_type, tuple):
if self.dist_node_labels is not None:
if isinstance(self.dist_node_labels, DistFeature):
fut = self.dist_node_labels.async_get(
output.node[input_type], input_type
)
nlabels = await wrap_torch_future(fut)
# DistFeature always returns [N, K]. We collapse K=1 to 1-D
# [N] to match GLT's convention and what downstream code
# (e.g. CrossEntropyLoss) expects for data.y. Multi-label
# (K>1) keeps the full 2-D matrix.
# TODO (mkolodner-sc): Consider investigating always returning
# 2-D — this may be a breaking change for single-label
# training pipelines (e.g. CrossEntropyLoss expects 1-D data.y).
result_map[f"{as_str(input_type)}.nlabels"] = (
nlabels if nlabels.shape[1] > 1 else nlabels.T[0]
)
else:
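                        # Local label tensors: plain indexing on this host, no
                        # RPC; GLT never applied ``.T[0]`` on this path.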
node_labels = self.dist_node_labels.get(input_type, None)
if node_labels is not None:
result_map[f"{as_str(input_type)}.nlabels"] = node_labels[
output.node[input_type].to(node_labels.device)
]
if self.dist_node_feature is not None:
if self.use_all2all:
sorted_ntype = sorted(self.dist_node_feature.feature_pb.keys())
nfeat_dict = self.dist_node_feature.get_all2all(
output, sorted_ntype
)
for ntype, nfeats in nfeat_dict.items():
result_map[f"{as_str(ntype)}.nfeats"] = nfeats
else:
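                    # Fan out one async feature fetch per node type first, then
                    # await the futures so remote lookups can overlap.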
nfeat_fut_dict = {}
for ntype, nodes in output.node.items():
nodes = nodes.to(torch.long)
nfeat_fut_dict[ntype] = self.dist_node_feature.async_get(
nodes, ntype
)
for ntype, fut in nfeat_fut_dict.items():
nfeats = await wrap_torch_future(fut)
result_map[f"{as_str(ntype)}.nfeats"] = nfeats
if self.dist_edge_feature is not None and self.with_edge:
efeat_fut_dict = {}
for etype in self.edge_types:
if self.edge_dir == "in":
eids = result_map.get(
f"{as_str(reverse_edge_type(etype))}.eids", None
)
elif self.edge_dir == "out":
eids = result_map.get(f"{as_str(etype)}.eids", None)
if eids is not None:
eids = eids.to(torch.long)
efeat_fut_dict[etype] = self.dist_edge_feature.async_get(
eids, etype
)
for etype, fut in efeat_fut_dict.items():
efeats = await wrap_torch_future(fut)
if self.edge_dir == "out":
result_map[f"{as_str(etype)}.efeats"] = efeats
elif self.edge_dir == "in":
result_map[f"{as_str(reverse_edge_type(etype))}.efeats"] = (
efeats
)
if output.batch is not None:
for ntype, batch in output.batch.items():
result_map[f"{as_str(ntype)}.batch"] = batch
else:
result_map["ids"] = output.node
result_map["rows"] = output.row
result_map["cols"] = output.col
if output.num_sampled_nodes is not None:
result_map["num_sampled_nodes"] = torch.tensor(
output.num_sampled_nodes, device=self.device
)
result_map["num_sampled_edges"] = torch.tensor(
output.num_sampled_edges, device=self.device
)
if self.with_edge:
result_map["eids"] = output.edge
if self.dist_node_labels is not None:
if isinstance(self.dist_node_labels, DistFeature):
fut = self.dist_node_labels.async_get(output.node)
nlabels = await wrap_torch_future(fut)
                    # DistFeature always returns [N, K]; see the comment in the
                    # heterogeneous branch above for why K=1 collapses to 1-D.
result_map["nlabels"] = (
nlabels if nlabels.shape[1] > 1 else nlabels.T[0]
)
else:
result_map["nlabels"] = self.dist_node_labels[
output.node.to(self.dist_node_labels.device)
]
if self.dist_node_feature is not None:
fut = self.dist_node_feature.async_get(output.node)
nfeats = await wrap_torch_future(fut)
result_map["nfeats"] = nfeats
if self.dist_edge_feature is not None:
eids = result_map["eids"]
fut = self.dist_edge_feature.async_get(eids)
efeats = await wrap_torch_future(fut)
result_map["efeats"] = efeats
if output.batch is not None:
result_map["batch"] = output.batch
return result_map

    async def _sample_from_nodes(
self,
inputs: NodeSamplerInput,
) -> Union[SamplerOutput, HeteroSamplerOutput]:
"""Sample subgraph from seed nodes.
Subclasses must override this method with their specific sampling
strategy.
Args:
inputs: The seed nodes to sample from.
Raises:
NotImplementedError: Always — subclasses must override.
"""
raise NotImplementedError(
f"{type(self).__name__} must override _sample_from_nodes."
)