# Source code for gigl.distributed.dist_neighbor_sampler

import asyncio
from typing import Union

import torch
from graphlearn_torch.sampler import (
    HeteroSamplerOutput,
    NeighborOutput,
    NodeSamplerInput,
    SamplerOutput,
)
from graphlearn_torch.typing import EdgeType, NodeType
from graphlearn_torch.utils import count_dict, merge_dict, reverse_edge_type

from gigl.distributed.base_sampler import BaseDistNeighborSampler


class DistNeighborSampler(BaseDistNeighborSampler):
    """GiGL's distributed k-hop neighbor sampler.

    Builds on :class:`BaseDistNeighborSampler` (which provides the shared
    input-preparation utilities) and implements ``_sample_from_nodes`` as a
    k-hop neighbor-sampling loop.

    Both plain ``NodeSamplerInput`` and ABLP inputs are supported: for ABLP
    (anchor-based link prediction), supervision nodes (positive/negative
    labels) are folded into the sampling seeds by the base class's input
    preparation, and the label information is carried in the output
    ``metadata``.
    """

    async def _sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample a subgraph around the seed nodes via k-hop expansion.

        Args:
            inputs: Seed specification; either a standard ``NodeSamplerInput``
                or an ABLP input carrying supervision nodes.

        Returns:
            ``HeteroSamplerOutput`` for heterogeneous graphs, otherwise
            ``SamplerOutput``. For ABLP inputs, label metadata produced during
            input preparation is attached to the output.
        """
        loop_inputs = self._prepare_sample_loop_inputs(inputs)
        input_type = inputs.input_type
        nodes_to_sample = loop_inputs.nodes_to_sample
        metadata = loop_inputs.metadata
        self.max_input_size: int = max(self.max_input_size, inputs.node.numel())
        inducer = self._acquire_inducer()
        is_hetero = self.dist_graph.data_cls == "hetero"

        hop_result: NeighborOutput
        if is_hetero:
            assert input_type is not None
            assert isinstance(nodes_to_sample, dict)
            out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {}
            out_rows_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            out_cols_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            out_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {}
            num_sampled_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {}
            num_sampled_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {}

            src_dict = inducer.init_node(nodes_to_sample)
            # Batch tracking uses the original anchor seeds (inputs.node),
            # not the deduped nodes_to_sample: for ABLP the latter also holds
            # supervision nodes, which must not count as part of the batch.
            batch = {input_type: inputs.node.to(self.device)}
            merge_dict(src_dict, out_nodes_hetero)
            count_dict(src_dict, num_sampled_nodes_hetero, 1)

            for hop in range(self.num_hops):
                task_dict: dict[EdgeType, asyncio.Task] = {}
                nbr_dict: dict[EdgeType, list[torch.Tensor]] = {}
                edge_dict: dict[EdgeType, torch.Tensor] = {}
                for etype in self.edge_types:
                    fanout = self.num_neighbors[etype][hop]
                    # Choose the frontier node type and the key under which
                    # this hop is recorded; for "in" edges the recorded key is
                    # the reversed edge type, while sampling itself still uses
                    # the original edge type.
                    if self.edge_dir == "in":
                        frontier = src_dict.get(etype[-1], None)
                        record_key = reverse_edge_type(etype)
                    elif self.edge_dir == "out":
                        frontier = src_dict.get(etype[0], None)
                        record_key = etype
                    else:
                        continue
                    if frontier is not None and frontier.numel() > 0:
                        task_dict[record_key] = self._loop.create_task(
                            self._sample_one_hop(frontier, fanout, etype)
                        )

                for etype, task in task_dict.items():
                    hop_result = await task
                    if hop_result.nbr.numel() == 0:
                        continue
                    nbr_dict[etype] = [
                        src_dict[etype[0]],
                        hop_result.nbr,
                        hop_result.nbr_num,
                    ]
                    if hop_result.edge is not None:
                        edge_dict[etype] = hop_result.edge

                if len(nbr_dict) == 0:
                    # Nothing sampled this hop for any edge type; try the
                    # next hop with the same frontier.
                    continue

                nodes_dict, rows_dict, cols_dict = inducer.induce_next(nbr_dict)
                merge_dict(nodes_dict, out_nodes_hetero)
                merge_dict(rows_dict, out_rows_hetero)
                merge_dict(cols_dict, out_cols_hetero)
                merge_dict(edge_dict, out_edges_hetero)
                # Seeds are layer 1, so hop i contributes node layer i + 2
                # and edge layer i + 1.
                count_dict(nodes_dict, num_sampled_nodes_hetero, hop + 2)
                count_dict(cols_dict, num_sampled_edges_hetero, hop + 1)
                src_dict = nodes_dict

            sample_output = HeteroSamplerOutput(
                node={
                    ntype: torch.cat(parts)
                    for ntype, parts in out_nodes_hetero.items()
                },
                row={
                    etype: torch.cat(parts)
                    for etype, parts in out_rows_hetero.items()
                },
                col={
                    etype: torch.cat(parts)
                    for etype, parts in out_cols_hetero.items()
                },
                edge=(
                    {
                        etype: torch.cat(parts)
                        for etype, parts in out_edges_hetero.items()
                    }
                    if self.with_edge
                    else None
                ),
                batch=batch,
                num_sampled_nodes=num_sampled_nodes_hetero,
                num_sampled_edges=num_sampled_edges_hetero,
                input_type=input_type,
                metadata=metadata,
            )
        else:
            assert input_type is None, (
                f"Expected input_type to be None for homogeneous graph, got {input_type}"
            )
            assert isinstance(nodes_to_sample, torch.Tensor)
            srcs = inducer.init_node(nodes_to_sample)
            # As above: the batch is the raw anchor seeds, excluding any ABLP
            # supervision nodes that were merged into nodes_to_sample.
            batch = inputs.node.to(self.device)

            out_nodes: list[torch.Tensor] = [srcs]
            out_edges: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
            num_sampled_nodes: list[int] = [srcs.size(0)]
            num_sampled_edges: list[int] = []

            for fanout in self.num_neighbors:
                hop_result = await self._sample_one_hop(srcs, fanout, None)
                if hop_result.nbr.numel() == 0:
                    # Frontier exhausted: no deeper hop can produce edges.
                    break
                nodes, rows, cols = inducer.induce_next(
                    srcs, hop_result.nbr, hop_result.nbr_num
                )
                out_nodes.append(nodes)
                out_edges.append((rows, cols, hop_result.edge))
                num_sampled_nodes.append(nodes.size(0))
                num_sampled_edges.append(cols.size(0))
                srcs = nodes

            if out_edges:
                row = torch.cat([rows for rows, _, _ in out_edges])
                col = torch.cat([cols for _, cols, _ in out_edges])
                edge = (
                    torch.cat([eids for _, _, eids in out_edges])
                    if self.with_edge
                    else None
                )
            else:
                # No edges sampled at all: emit empty index tensors.
                row = torch.empty(0, dtype=torch.long, device=self.device)
                col = torch.empty(0, dtype=torch.long, device=self.device)
                edge = (
                    torch.empty(0, dtype=torch.long, device=self.device)
                    if self.with_edge
                    else None
                )
            sample_output = SamplerOutput(
                node=torch.cat(out_nodes),
                row=row,
                col=col,
                edge=edge,
                batch=batch,
                num_sampled_nodes=num_sampled_nodes,
                num_sampled_edges=num_sampled_edges,
                metadata=metadata,
            )

        # Return the inducer to the pool for reuse by subsequent requests.
        self.inducer_pool.put(inducer)
        return sample_output