import asyncio
import gc
from collections import defaultdict
from dataclasses import dataclass
from typing import Union
import torch
from graphlearn_torch.distributed import DistNeighborSampler as GLTDistNeighborSampler
from graphlearn_torch.sampler import (
HeteroSamplerOutput,
NeighborOutput,
NodeSamplerInput,
SamplerOutput,
)
from graphlearn_torch.typing import EdgeType, NodeType
from graphlearn_torch.utils import count_dict, merge_dict, reverse_edge_type
from gigl.distributed.sampler import (
NEGATIVE_LABEL_METADATA_KEY,
POSITIVE_LABEL_METADATA_KEY,
ABLPNodeSamplerInput,
)
from gigl.utils.data_splitters import PADDING_NODE
def _stable_unique_preserve_order(nodes: torch.Tensor) -> torch.Tensor:
"""Return unique 1-D values while preserving first-occurrence order."""
if nodes.dim() != 1:
raise ValueError(
f"Expected a 1-D tensor of node ids, got shape {tuple(nodes.shape)}."
)
if nodes.numel() <= 1:
return nodes
unique_nodes, inverse = torch.unique(nodes, sorted=False, return_inverse=True)
first_positions = torch.full(
(unique_nodes.numel(),),
fill_value=nodes.numel(),
dtype=torch.long,
device=nodes.device,
)
positions = torch.arange(nodes.numel(), device=nodes.device)
first_positions.scatter_reduce_(
0,
inverse,
positions,
reduce="amin",
include_self=True,
)
stable_order = torch.argsort(first_positions)
return unique_nodes[stable_order]
@dataclass
class SampleLoopInputs:
    """Container for the prepared inputs to the sampling loop.

    NOTE(review): this definition was reconstructed — the original field list
    was replaced by a stray ``[docs]`` Sphinx-extraction artifact, leaving the
    ``@dataclass`` decorator dangling over ``DistNeighborSampler`` (which would
    have generated a no-arg ``__init__`` shadowing the GLT parent's). Field
    names and shapes are taken from the construction sites in
    ``_prepare_sample_loop_inputs`` and the reads in ``_sample_from_nodes``;
    confirm against the upstream GiGL source.
    """

    # Seeds to sample from: a bare tensor for homogeneous graphs, or a mapping
    # from node type to seed tensor for heterogeneous graphs.
    nodes_to_sample: "Union[torch.Tensor, dict[Union[str, NodeType], torch.Tensor]]"
    # Task-specific metadata (e.g. ABLP positive/negative label tensors keyed
    # by string metadata keys). Empty dict for standard neighbor sampling.
    metadata: "dict[str, torch.Tensor]"
class DistNeighborSampler(GLTDistNeighborSampler):
    """GiGL's distributed neighbor sampler supporting both standard and ABLP inputs.

    Extends GLT's DistNeighborSampler and overrides _sample_from_nodes to support
    both NodeSamplerInput (standard neighbor sampling) and ABLPNodeSamplerInput
    (anchor-based link prediction with supervision nodes).

    For ABLPNodeSamplerInput, supervision nodes (positive/negative labels) are
    added to the sampling seeds, and label information is included in the output
    metadata.
    """

    def _prepare_sample_loop_inputs(
        self,
        inputs: NodeSamplerInput,
    ) -> SampleLoopInputs:
        """Prepare inputs for the sampling loop.

        Handles both standard NodeSamplerInput and ABLPNodeSamplerInput.
        For ABLP inputs, adds supervision nodes to the sampling seeds and
        builds label metadata.

        Args:
            inputs: Either a NodeSamplerInput or ABLPNodeSamplerInput.

        Returns:
            SampleLoopInputs containing the nodes to sample from and any
            metadata related to the task (e.g., label tensors for ABLP).
        """
        input_seeds = inputs.node.to(self.device)
        input_type = inputs.input_type
        # ABLP inputs carry supervision labels; delegate to the ABLP path.
        if isinstance(inputs, ABLPNodeSamplerInput):
            return self._prepare_ablp_inputs(inputs, input_seeds, input_type)
        # For homogeneous graphs (input_type is None), return tensor directly.
        # For heterogeneous graphs, return dict mapping node type to tensor.
        if input_type is None:
            return SampleLoopInputs(
                nodes_to_sample=input_seeds,
                metadata={},
            )
        return SampleLoopInputs(
            nodes_to_sample={input_type: input_seeds},
            metadata={},
        )

    def _prepare_ablp_inputs(
        self,
        inputs: ABLPNodeSamplerInput,
        input_seeds: torch.Tensor,
        input_type: NodeType,
    ) -> SampleLoopInputs:
        """Prepare ABLP inputs with supervision nodes and label metadata.

        Args:
            inputs: The ABLPNodeSamplerInput containing label information.
            input_seeds: The anchor node seeds (already moved to device).
            input_type: The node type of the anchor seeds.

        Returns:
            SampleLoopInputs with supervision nodes included in nodes_to_sample
            and label tensors in metadata.
        """
        # Since GLT swaps src/dst for edge_dir = "out",
        # and GiGL assumes that supervision edge types are always
        # (anchor_node_type, to, supervision_node_type),
        # we need to index into supervision edge types accordingly.
        label_edge_index = 0 if self.edge_dir == "in" else 2
        # Build metadata and input nodes from positive/negative labels.
        # We need to sample from the supervision nodes as well, and ensure
        # that we are sampling from the correct node type.
        metadata: dict[str, torch.Tensor] = {}
        input_seeds_builder: dict[
            Union[str, NodeType], list[torch.Tensor]
        ] = defaultdict(list)
        input_seeds_builder[input_type].append(input_seeds)
        for edge_type, label_tensor in inputs.positive_label_by_edge_types.items():
            # Drop PADDING_NODE filler entries before using labels as seeds.
            filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to(
                self.device
            )
            input_seeds_builder[edge_type[label_edge_index]].append(
                filtered_label_tensor
            )
            # Update the metadata per positive label edge type.
            # We do this because GLT only supports dict[str, torch.Tensor] for metadata.
            # Note: the *unfiltered* label tensor (padding included) is stored,
            # so downstream consumers can recover per-anchor label alignment.
            metadata[
                f"{POSITIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}"
            ] = label_tensor
        for edge_type, label_tensor in inputs.negative_label_by_edge_types.items():
            # Drop PADDING_NODE filler entries before using labels as seeds.
            filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to(
                self.device
            )
            input_seeds_builder[edge_type[label_edge_index]].append(
                filtered_label_tensor
            )
            # Update the metadata per negative label edge type.
            # We do this because GLT only supports dict[str, torch.Tensor] for metadata.
            metadata[
                f"{NEGATIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}"
            ] = label_tensor
        nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] = {
            # Keep first-occurrence order so anchor seeds remain at the front of
            # their node type; graph-transformer paths rely on that convention.
            node_type: _stable_unique_preserve_order(
                torch.cat(seeds, dim=0).to(self.device)
            )
            for node_type, seeds in input_seeds_builder.items()
        }
        # Memory cleanup — only del loop vars if any labels were processed
        # (the loop variables are unbound when both label dicts are empty).
        has_labels = bool(
            inputs.positive_label_by_edge_types or inputs.negative_label_by_edge_types
        )
        if has_labels:
            del filtered_label_tensor, label_tensor
        for value in input_seeds_builder.values():
            value.clear()
        input_seeds_builder.clear()
        del input_seeds_builder
        # NOTE(review): gc.collect() on every prepared batch is expensive in a
        # hot sampling loop; confirm this explicit collection is actually
        # needed (e.g. to promptly release pinned/device memory).
        gc.collect()
        return SampleLoopInputs(
            nodes_to_sample=nodes_to_sample,
            metadata=metadata,
        )

    async def _sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample subgraph from seed nodes.

        Supports both NodeSamplerInput and ABLPNodeSamplerInput. For ABLP,
        supervision nodes are included in sampling and label metadata is
        attached to the output.
        """
        sample_loop_inputs = self._prepare_sample_loop_inputs(inputs)
        input_type = inputs.input_type
        nodes_to_sample = sample_loop_inputs.nodes_to_sample
        metadata = sample_loop_inputs.metadata
        # Track the largest seed batch seen so far (inherited GLT state).
        self.max_input_size: int = max(self.max_input_size, inputs.node.numel())
        inducer = self._acquire_inducer()
        is_hetero = self.dist_graph.data_cls == "hetero"
        output: NeighborOutput
        if is_hetero:
            assert input_type is not None
            assert isinstance(nodes_to_sample, dict)
            # Per-type accumulators; each list gains one entry per hop.
            out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {}
            out_rows_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            out_cols_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            out_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            # NOTE(review): count_dict appears to append per-hop counts, so the
            # element type here is likely int, not torch.Tensor — confirm
            # against graphlearn_torch.utils.count_dict before changing.
            num_sampled_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {}
            num_sampled_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            src_dict = inducer.init_node(nodes_to_sample)
            # Use the original anchor seeds (inputs.node) for batch tracking,
            # not the deduped nodes_to_sample. For ABLP, nodes_to_sample includes
            # supervision nodes which should not be part of the batch.
            batch = {input_type: inputs.node.to(self.device)}
            merge_dict(src_dict, out_nodes_hetero)
            count_dict(src_dict, num_sampled_nodes_hetero, 1)
            for i in range(self.num_hops):
                task_dict: dict[EdgeType, asyncio.Task] = {}
                nbr_dict: dict[EdgeType, list[torch.Tensor]] = {}
                edge_dict: dict[EdgeType, torch.Tensor] = {}
                # Launch one async one-hop sampling task per edge type that has
                # live source nodes this hop.
                for etype in self.edge_types:
                    req_num = self.num_neighbors[etype][i]
                    if self.edge_dir == "in":
                        # For "in" sampling, hop from the destination end; the
                        # result is keyed by the reversed edge type.
                        srcs = src_dict.get(etype[-1], None)
                        if srcs is not None and srcs.numel() > 0:
                            task_dict[
                                reverse_edge_type(etype)
                            ] = self._loop.create_task(
                                self._sample_one_hop(srcs, req_num, etype)
                            )
                    elif self.edge_dir == "out":
                        srcs = src_dict.get(etype[0], None)
                        if srcs is not None and srcs.numel() > 0:
                            task_dict[etype] = self._loop.create_task(
                                self._sample_one_hop(srcs, req_num, etype)
                            )
                # Await all hop results and collect non-empty neighborhoods.
                for etype, task in task_dict.items():
                    output = await task
                    if output.nbr.numel() == 0:
                        continue
                    nbr_dict[etype] = [src_dict[etype[0]], output.nbr, output.nbr_num]
                    if output.edge is not None:
                        edge_dict[etype] = output.edge
                if len(nbr_dict) == 0:
                    continue
                nodes_dict, rows_dict, cols_dict = inducer.induce_next(nbr_dict)
                merge_dict(nodes_dict, out_nodes_hetero)
                merge_dict(rows_dict, out_rows_hetero)
                merge_dict(cols_dict, out_cols_hetero)
                merge_dict(edge_dict, out_edges_hetero)
                count_dict(nodes_dict, num_sampled_nodes_hetero, i + 2)
                count_dict(cols_dict, num_sampled_edges_hetero, i + 1)
                # Newly induced nodes become the frontier for the next hop.
                src_dict = nodes_dict
            sample_output = HeteroSamplerOutput(
                node={
                    ntype: torch.cat(nodes) for ntype, nodes in out_nodes_hetero.items()
                },
                row={etype: torch.cat(rows) for etype, rows in out_rows_hetero.items()},
                col={etype: torch.cat(cols) for etype, cols in out_cols_hetero.items()},
                edge=(
                    {etype: torch.cat(eids) for etype, eids in out_edges_hetero.items()}
                    if self.with_edge
                    else None
                ),
                batch=batch,
                num_sampled_nodes=num_sampled_nodes_hetero,
                num_sampled_edges=num_sampled_edges_hetero,
                input_type=input_type,
                metadata=metadata,
            )
        else:
            assert (
                input_type is None
            ), f"Expected input_type to be None for homogeneous graph, got {input_type}"
            assert isinstance(nodes_to_sample, torch.Tensor)
            srcs = inducer.init_node(nodes_to_sample)
            # Use the original anchor seeds (inputs.node) for batch tracking,
            # not the deduped nodes_to_sample. For ABLP, nodes_to_sample includes
            # supervision nodes which should not be part of the batch.
            batch = inputs.node.to(self.device)
            out_nodes: list[torch.Tensor] = []
            out_edges: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
            # Per-hop counts (size(0) returns plain ints).
            num_sampled_nodes: list[int] = []
            num_sampled_edges: list[int] = []
            out_nodes.append(srcs)
            num_sampled_nodes.append(srcs.size(0))
            for req_num in self.num_neighbors:
                output = await self._sample_one_hop(srcs, req_num, None)
                # Stop early once a hop yields no neighbors at all.
                if output.nbr.numel() == 0:
                    break
                nodes, rows, cols = inducer.induce_next(
                    srcs, output.nbr, output.nbr_num
                )
                out_nodes.append(nodes)
                out_edges.append((rows, cols, output.edge))
                num_sampled_nodes.append(nodes.size(0))
                num_sampled_edges.append(cols.size(0))
                srcs = nodes
            if not out_edges:
                # No hop produced edges (e.g. isolated seeds): emit an output
                # with the seed nodes and empty edge index tensors.
                sample_output = SamplerOutput(
                    node=torch.cat(out_nodes),
                    row=torch.empty(0, dtype=torch.long, device=self.device),
                    col=torch.empty(0, dtype=torch.long, device=self.device),
                    edge=(
                        torch.empty(0, dtype=torch.long, device=self.device)
                        if self.with_edge
                        else None
                    ),
                    batch=batch,
                    num_sampled_nodes=num_sampled_nodes,
                    num_sampled_edges=num_sampled_edges,
                    metadata=metadata,
                )
            else:
                sample_output = SamplerOutput(
                    node=torch.cat(out_nodes),
                    row=torch.cat([e[0] for e in out_edges]),
                    col=torch.cat([e[1] for e in out_edges]),
                    edge=(
                        torch.cat([e[2] for e in out_edges]) if self.with_edge else None
                    ),
                    batch=batch,
                    num_sampled_nodes=num_sampled_nodes,
                    num_sampled_edges=num_sampled_edges,
                    metadata=metadata,
                )
        # Return the inducer for reuse by subsequent sampling calls.
        self.inducer_pool.put(inducer)
        return sample_output