gigl.distributed.dist_neighbor_sampler

Classes

DistNeighborSampler

GiGL's distributed neighbor sampler supporting both standard and ABLP inputs.

SampleLoopInputs

Inputs prepared for the neighbor sampling loop in _sample_from_nodes.

Module Contents

class gigl.distributed.dist_neighbor_sampler.DistNeighborSampler(data, num_neighbors=None, with_edge=False, with_neg=False, with_weight=False, edge_dir='out', collect_features=False, channel=None, use_all2all=False, concurrency=1, device=None, seed=None)

Bases: graphlearn_torch.distributed.DistNeighborSampler

GiGL’s distributed neighbor sampler supporting both standard and ABLP inputs.

Extends GLT’s (graphlearn_torch) DistNeighborSampler and overrides _sample_from_nodes so that it accepts both NodeSamplerInput (standard neighbor sampling) and ABLPNodeSamplerInput (anchor-based link prediction with supervision nodes).

For ABLPNodeSamplerInput, the supervision nodes (positive and negative label nodes) are added to the sampling seeds alongside the anchor nodes, so their neighborhoods are sampled as well, and the label information is attached to the output metadata.
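The seed-merging step described above can be sketched in plain Python. This is a minimal, dependency-free illustration, not GiGL's implementation: lists stand in for torch.Tensor, and the helper `merge_ablp_seeds`, the `MergedSeeds` container, and the metadata keys `positive_labels`/`negative_labels` are all hypothetical. It assumes a heterogeneous graph in which anchors and supervision nodes may have different node types, and it deduplicates seeds per node type, which is an assumption about the desired behavior rather than something stated above.

```python
from dataclasses import dataclass, field


@dataclass
class MergedSeeds:
    # Seeds per node type (lists stand in for torch.Tensor in this sketch).
    nodes_to_sample: dict[str, list[int]]
    # Label information forwarded to the sampler output (hypothetical keys).
    metadata: dict[str, list[list[int]]] = field(default_factory=dict)


def merge_ablp_seeds(
    anchor_type: str,
    anchors: list[int],
    supervision_type: str,
    positive_labels: list[list[int]],
    negative_labels: list[list[int]],
) -> MergedSeeds:
    """Hypothetical helper: add supervision nodes to the anchor seeds."""
    seeds: dict[str, list[int]] = {anchor_type: list(anchors)}
    # Flatten the per-anchor label lists and append them to the seeds of
    # the supervision node type (which may equal the anchor type).
    supervision = [n for labels in positive_labels + negative_labels for n in labels]
    seeds.setdefault(supervision_type, []).extend(supervision)
    # Deduplicate while preserving first-seen order, so a node that is both
    # an anchor and a label (or a label of several anchors) appears once.
    nodes_to_sample = {t: list(dict.fromkeys(ids)) for t, ids in seeds.items()}
    metadata = {"positive_labels": positive_labels, "negative_labels": negative_labels}
    return MergedSeeds(nodes_to_sample, metadata)
```

For example, `merge_ablp_seeds("user", [0, 1], "item", [[10], [11, 10]], [[12], [13]])` yields `nodes_to_sample == {"user": [0, 1], "item": [10, 11, 12, 13]}`, while the per-anchor label lists are kept intact in `metadata`.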

Parameters:
  • data (graphlearn_torch.distributed.dist_dataset.DistDataset)

  • num_neighbors (Optional[graphlearn_torch.typing.NumNeighbors])

  • with_edge (bool)

  • with_neg (bool)

  • with_weight (bool)

  • edge_dir (Literal['in', 'out'])

  • collect_features (bool)

  • channel (Optional[graphlearn_torch.channel.ChannelBase])

  • use_all2all (bool)

  • concurrency (int)

  • device (Optional[torch.device])

  • seed (int)
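To make `num_neighbors`, `edge_dir`, and `seed` more concrete, here is a toy, single-process sketch of fanout-based neighbor sampling over a homogeneous edge list. It is not GLT's sampler: the function `sample_hops` and its edge-list input are hypothetical, and it only assumes the usual meaning of these parameters, namely one fanout per hop in `num_neighbors`, `edge_dir` choosing whether neighbors are followed along outgoing ('out') or incoming ('in') edges, and `seed` making the random choices reproducible.

```python
import random
from collections import defaultdict
from typing import Optional


def sample_hops(
    edges: list[tuple[int, int]],
    seeds: list[int],
    num_neighbors: list[int],
    edge_dir: str = "out",
    seed: Optional[int] = None,
) -> list[list[tuple[int, int]]]:
    """Toy multi-hop sampler: returns the sampled (src, dst) edges per hop."""
    if edge_dir not in ("in", "out"):
        raise ValueError("edge_dir must be 'in' or 'out'")
    rng = random.Random(seed)
    # Index each node's neighbors in the requested direction.
    neighbors: dict[int, list[int]] = defaultdict(list)
    for src, dst in edges:
        if edge_dir == "out":
            neighbors[src].append(dst)
        else:
            neighbors[dst].append(src)

    frontier = list(dict.fromkeys(seeds))
    visited = set(frontier)
    hops = []
    for fanout in num_neighbors:  # one fanout per hop
        hop_edges = []
        next_frontier = []
        for node in frontier:
            candidates = neighbors.get(node, [])
            # Sample without replacement; keep all neighbors if fewer than fanout.
            picked = rng.sample(candidates, min(fanout, len(candidates)))
            for nbr in picked:
                # Report edges in their original (src, dst) orientation.
                hop_edges.append((node, nbr) if edge_dir == "out" else (nbr, node))
                if nbr not in visited:
                    visited.add(nbr)
                    next_frontier.append(nbr)
        hops.append(hop_edges)
        frontier = next_frontier  # only newly reached nodes expand next hop
    return hops
```

With `edges = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)]`, calling `sample_hops(edges, [0], [2, 2])` reaches nodes 1 and 2 in the first hop and node 3 in the second; with `edge_dir="in"` the same graph is walked backwards, so seeding node 3 reaches its predecessors 1 and 2.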

class gigl.distributed.dist_neighbor_sampler.SampleLoopInputs

Inputs prepared for the neighbor sampling loop in _sample_from_nodes.

This dataclass holds the processed inputs that are passed to the core sampling loop. It lets _prepare_sample_loop_inputs customize which seed nodes are sampled from and which metadata is attached to the output, without duplicating the sampling-loop logic.
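The split between preparing inputs and running the loop can be sketched as follows. This is a simplified, dependency-free model rather than GiGL's code: `NodeInput`, `ABLPInput`, `ToySampler`, `_sample_loop`, and `sample_from_nodes` are hypothetical stand-ins, the local `SampleLoopInputs` only mirrors the two documented fields, lists replace tensors, and the "loop" just echoes its seeds. The point is that only `_prepare_sample_loop_inputs` branches on the input type, while the loop itself is shared.

```python
from dataclasses import dataclass, field
from typing import Union


@dataclass
class NodeInput:
    # Hypothetical stand-in for a standard node sampler input.
    nodes: list[int]


@dataclass
class ABLPInput:
    # Hypothetical stand-in for an anchor-based link prediction input.
    anchors: list[int]
    positive_labels: list[int]
    negative_labels: list[int]


@dataclass
class SampleLoopInputs:
    # Simplified mirror of the documented fields (lists instead of tensors).
    nodes_to_sample: list[int]
    metadata: dict[str, list[int]] = field(default_factory=dict)


class ToySampler:
    def _prepare_sample_loop_inputs(
        self, inputs: Union[NodeInput, ABLPInput]
    ) -> SampleLoopInputs:
        if isinstance(inputs, ABLPInput):
            # ABLP: supervision nodes join the seeds, labels go to metadata.
            return SampleLoopInputs(
                nodes_to_sample=inputs.anchors + inputs.positive_labels + inputs.negative_labels,
                metadata={
                    "positive_labels": inputs.positive_labels,
                    "negative_labels": inputs.negative_labels,
                },
            )
        # Standard input: seed directly from the given nodes, no extra metadata.
        return SampleLoopInputs(nodes_to_sample=list(inputs.nodes))

    def _sample_loop(self, loop_inputs: SampleLoopInputs) -> dict:
        # Shared loop: a real sampler would expand neighbors hop by hop here.
        return {"seeds": loop_inputs.nodes_to_sample, "metadata": loop_inputs.metadata}

    def sample_from_nodes(self, inputs: Union[NodeInput, ABLPInput]) -> dict:
        return self._sample_loop(self._prepare_sample_loop_inputs(inputs))
```

Supporting a new input type in this sketch only requires another branch in `_prepare_sample_loop_inputs`; `_sample_loop` stays untouched.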

nodes_to_sample: torch.Tensor | dict[graphlearn_torch.typing.NodeType, torch.Tensor]

For homogeneous graphs, a tensor of node IDs. For heterogeneous graphs, a dict mapping node types to tensors of node IDs. For ABLP inputs, this also includes the supervision nodes (positive/negative labels).

metadata: dict[str, torch.Tensor]

Metadata dict to attach to the sampler output (e.g., label tensors).