import asyncio
from typing import Union
import torch
from graphlearn_torch.sampler import (
HeteroSamplerOutput,
NeighborOutput,
NodeSamplerInput,
SamplerOutput,
)
from graphlearn_torch.typing import EdgeType, NodeType
from graphlearn_torch.utils import count_dict, merge_dict, reverse_edge_type
from gigl.distributed.base_sampler import BaseDistNeighborSampler
# [docs]  -- Sphinx "view source" link artifact from the documentation build; not part of the source code.
class DistNeighborSampler(BaseDistNeighborSampler):
    """GiGL's k-hop distributed neighbor sampler supporting both standard and ABLP inputs.

    Extends BaseDistNeighborSampler (which provides shared input preparation
    utilities) and overrides _sample_from_nodes with a k-hop neighbor sampling
    loop.

    Supports both NodeSamplerInput (standard neighbor sampling) and
    ABLPNodeSamplerInput (anchor-based link prediction with supervision nodes).
    For ABLPNodeSamplerInput, supervision nodes (positive/negative labels) are
    added to the sampling seeds, and label information is included in the output
    metadata.
    """

    async def _sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample a subgraph from seed nodes using k-hop neighbor sampling.

        Supports both NodeSamplerInput and ABLPNodeSamplerInput. For ABLP,
        supervision nodes are included in sampling and label metadata is
        attached to the output.

        Args:
            inputs: Seed nodes to sample from. For ABLP inputs,
                _prepare_sample_loop_inputs also folds supervision nodes into
                the sampling seeds and exposes label info via metadata.

        Returns:
            HeteroSamplerOutput when the backing graph is heterogeneous,
            SamplerOutput otherwise.
        """
        # Shared (base-class) preparation: dedupes/augments the seeds and
        # builds the output metadata (e.g. ABLP label info).
        sample_loop_inputs = self._prepare_sample_loop_inputs(inputs)
        input_type = inputs.input_type
        nodes_to_sample = sample_loop_inputs.nodes_to_sample
        metadata = sample_loop_inputs.metadata

        # Track the largest seed batch seen so far.
        # NOTE(review): presumably used elsewhere for buffer sizing -- the
        # consumer is not visible in this file; confirm in the base class.
        self.max_input_size: int = max(self.max_input_size, inputs.node.numel())
        inducer = self._acquire_inducer()
        is_hetero = self.dist_graph.data_cls == "hetero"
        output: NeighborOutput
        if is_hetero:
            assert input_type is not None
            assert isinstance(nodes_to_sample, dict)
            # Per-type accumulators; each hop appends one tensor per
            # node/edge type, concatenated into the final output below.
            out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {}
            out_rows_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            out_cols_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            out_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            num_sampled_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {}
            num_sampled_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            src_dict = inducer.init_node(nodes_to_sample)
            # Use the original anchor seeds (inputs.node) for batch tracking,
            # not the deduped nodes_to_sample. For ABLP, nodes_to_sample includes
            # supervision nodes which should not be part of the batch.
            batch = {input_type: inputs.node.to(self.device)}
            merge_dict(src_dict, out_nodes_hetero)
            count_dict(src_dict, num_sampled_nodes_hetero, 1)
            for i in range(self.num_hops):
                # Fan out one concurrent sampling task per edge type that has
                # frontier nodes to expand at this hop.
                task_dict: dict[EdgeType, asyncio.Task] = {}
                nbr_dict: dict[EdgeType, list[torch.Tensor]] = {}
                edge_dict: dict[EdgeType, torch.Tensor] = {}
                for etype in self.edge_types:
                    # Fanout requested for this edge type at hop i.
                    req_num = self.num_neighbors[etype][i]
                    if self.edge_dir == "in":
                        # Incoming edges: expand from the destination side
                        # (etype[-1]) and key the task by the reversed edge
                        # type so downstream rows/cols match that key's
                        # (src, rel, dst) orientation.
                        srcs = src_dict.get(etype[-1], None)
                        if srcs is not None and srcs.numel() > 0:
                            task_dict[reverse_edge_type(etype)] = (
                                self._loop.create_task(
                                    self._sample_one_hop(srcs, req_num, etype)
                                )
                            )
                    elif self.edge_dir == "out":
                        # Outgoing edges: expand from the source side (etype[0]).
                        srcs = src_dict.get(etype[0], None)
                        if srcs is not None and srcs.numel() > 0:
                            task_dict[etype] = self._loop.create_task(
                                self._sample_one_hop(srcs, req_num, etype)
                            )
                # Await all hop tasks; collect non-empty results. Note etype
                # here is the task key (already reversed for edge_dir == "in"),
                # so etype[0] is exactly the node type whose frontier was
                # expanded above.
                for etype, task in task_dict.items():
                    output = await task
                    if output.nbr.numel() == 0:
                        # No neighbors sampled for this edge type at this hop.
                        continue
                    nbr_dict[etype] = [src_dict[etype[0]], output.nbr, output.nbr_num]
                    if output.edge is not None:
                        edge_dict[etype] = output.edge
                if len(nbr_dict) == 0:
                    # Nothing sampled for any edge type at this hop; keep the
                    # current frontier (src_dict) and try the next hop.
                    continue
                # Induce the next frontier and the subgraph rows/cols for the
                # edges discovered this hop.
                nodes_dict, rows_dict, cols_dict = inducer.induce_next(nbr_dict)
                merge_dict(nodes_dict, out_nodes_hetero)
                merge_dict(rows_dict, out_rows_hetero)
                merge_dict(cols_dict, out_cols_hetero)
                merge_dict(edge_dict, out_edges_hetero)
                # Hop counters are 1-indexed with the seeds recorded at 1, so
                # nodes from hop i land at position i + 2.
                count_dict(nodes_dict, num_sampled_nodes_hetero, i + 2)
                count_dict(cols_dict, num_sampled_edges_hetero, i + 1)
                src_dict = nodes_dict
            # Concatenate per-hop accumulators into the final hetero output.
            sample_output = HeteroSamplerOutput(
                node={
                    ntype: torch.cat(nodes) for ntype, nodes in out_nodes_hetero.items()
                },
                row={etype: torch.cat(rows) for etype, rows in out_rows_hetero.items()},
                col={etype: torch.cat(cols) for etype, cols in out_cols_hetero.items()},
                edge=(
                    {etype: torch.cat(eids) for etype, eids in out_edges_hetero.items()}
                    if self.with_edge
                    else None
                ),
                batch=batch,
                num_sampled_nodes=num_sampled_nodes_hetero,
                num_sampled_edges=num_sampled_edges_hetero,
                input_type=input_type,
                metadata=metadata,
            )
        else:
            assert input_type is None, (
                f"Expected input_type to be None for homogeneous graph, got {input_type}"
            )
            assert isinstance(nodes_to_sample, torch.Tensor)
            srcs = inducer.init_node(nodes_to_sample)
            # Use the original anchor seeds (inputs.node) for batch tracking,
            # not the deduped nodes_to_sample. For ABLP, nodes_to_sample includes
            # supervision nodes which should not be part of the batch.
            batch = inputs.node.to(self.device)
            out_nodes: list[torch.Tensor] = []
            # Each entry is (rows, cols, edge_ids) for one hop.
            # NOTE(review): edge_ids looks like it may be None when edge ids
            # are not tracked -- it is only concatenated under self.with_edge;
            # confirm against _sample_one_hop.
            out_edges: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
            # Per-hop node/edge counts (plain ints from Tensor.size(0)).
            num_sampled_nodes: list[int] = []
            num_sampled_edges: list[int] = []
            out_nodes.append(srcs)
            num_sampled_nodes.append(srcs.size(0))
            for req_num in self.num_neighbors:
                output = await self._sample_one_hop(srcs, req_num, None)
                if output.nbr.numel() == 0:
                    # Frontier exhausted -- no deeper hops are possible.
                    break
                nodes, rows, cols = inducer.induce_next(
                    srcs, output.nbr, output.nbr_num
                )
                out_nodes.append(nodes)
                out_edges.append((rows, cols, output.edge))
                num_sampled_nodes.append(nodes.size(0))
                num_sampled_edges.append(cols.size(0))
                srcs = nodes
            if not out_edges:
                # No edges sampled at any hop: emit the seed nodes with empty
                # row/col (and empty edge ids when requested).
                sample_output = SamplerOutput(
                    node=torch.cat(out_nodes),
                    row=torch.empty(0, dtype=torch.long, device=self.device),
                    col=torch.empty(0, dtype=torch.long, device=self.device),
                    edge=(
                        torch.empty(0, dtype=torch.long, device=self.device)
                        if self.with_edge
                        else None
                    ),
                    batch=batch,
                    num_sampled_nodes=num_sampled_nodes,
                    num_sampled_edges=num_sampled_edges,
                    metadata=metadata,
                )
            else:
                sample_output = SamplerOutput(
                    node=torch.cat(out_nodes),
                    row=torch.cat([e[0] for e in out_edges]),
                    col=torch.cat([e[1] for e in out_edges]),
                    edge=(
                        torch.cat([e[2] for e in out_edges]) if self.with_edge else None
                    ),
                    batch=batch,
                    num_sampled_nodes=num_sampled_nodes,
                    num_sampled_edges=num_sampled_edges,
                    metadata=metadata,
                )
        # Return the inducer to the pool for reuse by subsequent requests.
        self.inducer_pool.put(inducer)
        return sample_output