gigl.distributed.base_sampler

Classes

BaseDistNeighborSampler

Base class for GiGL distributed samplers.

SampleLoopInputs

Inputs prepared for the neighbor sampling loop in _sample_from_nodes.

Module Contents

class gigl.distributed.base_sampler.BaseDistNeighborSampler(data, num_neighbors=None, with_edge=False, with_neg=False, with_weight=False, edge_dir='out', collect_features=False, channel=None, use_all2all=False, concurrency=1, device=None, seed=None)

Bases: graphlearn_torch.distributed.DistNeighborSampler

Base class for GiGL distributed samplers.

Extends DistNeighborSampler from GraphLearn-for-PyTorch (GLT) with shared utilities for preparing sampling inputs, including support for ABLP (anchor-based link prediction).

Subclasses must override _sample_from_nodes with their specific sampling strategy (e.g., k-hop neighbor sampling or Personalized PageRank (PPR) based sampling).

Parameters:
  • data (graphlearn_torch.distributed.dist_dataset.DistDataset)

  • num_neighbors (Optional[graphlearn_torch.typing.NumNeighbors])

  • with_edge (bool)

  • with_neg (bool)

  • with_weight (bool)

  • edge_dir (Literal['in', 'out'])

  • collect_features (bool)

  • channel (Optional[graphlearn_torch.channel.ChannelBase])

  • use_all2all (bool)

  • concurrency (int)

  • device (Optional[torch.device])

  • seed (int)
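The split described above, with shared input preparation in the base class and the sampling strategy in _sample_from_nodes, resembles the template-method pattern. The pure-Python sketch below illustrates that structure only. It is a minimal sketch under stated assumptions: it does not use GLT or GiGL, the names SketchBaseSampler, SketchKHopSampler, LoopInputs and sample are hypothetical, and the real methods have different (simplified away) signatures.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class LoopInputs:
    """Hypothetical stand-in for SampleLoopInputs; plain lists replace tensors."""

    nodes_to_sample: list
    metadata: dict = field(default_factory=dict)


class SketchBaseSampler(ABC):
    """Hypothetical base class: shared input preparation, pluggable sampling."""

    def __init__(self, adjacency):
        # adjacency: node ID -> list of out-neighbor IDs (toy in-memory graph).
        self.adjacency = adjacency

    def _prepare_sample_loop_inputs(self, seeds):
        # Shared step: every subclass receives its inputs in the same shape.
        return LoopInputs(nodes_to_sample=list(seeds))

    def sample(self, seeds):
        inputs = self._prepare_sample_loop_inputs(seeds)
        output = self._sample_from_nodes(inputs)
        # Attach whatever metadata the preparation step produced.
        output["metadata"] = inputs.metadata
        return output

    @abstractmethod
    def _sample_from_nodes(self, inputs):
        """Subclasses implement their sampling strategy here."""


class SketchKHopSampler(SketchBaseSampler):
    """Toy k-hop strategy: keeps the first `fanout` out-neighbors per hop.

    Real neighbor samplers choose neighbors at random; taking the first
    `fanout` keeps this illustration deterministic.
    """

    def __init__(self, adjacency, num_neighbors):
        super().__init__(adjacency)
        self.num_neighbors = num_neighbors  # fanout per hop, e.g. [2, 1]

    def _sample_from_nodes(self, inputs):
        frontier = list(dict.fromkeys(inputs.nodes_to_sample))
        nodes = list(frontier)  # sampled nodes in discovery order
        visited = set(frontier)
        edges = []
        for fanout in self.num_neighbors:
            next_frontier = []
            for src in frontier:
                for dst in self.adjacency.get(src, [])[:fanout]:
                    edges.append((src, dst))
                    if dst not in visited:
                        visited.add(dst)
                        nodes.append(dst)
                        next_frontier.append(dst)
            frontier = next_frontier
        return {"nodes": nodes, "edges": edges}
```

For example, SketchKHopSampler({0: [1, 2, 3], 1: [4], 2: [4, 5]}, num_neighbors=[2, 1]).sample([0]) keeps two neighbors of node 0 and then one neighbor of each of those. A PPR-based subclass in this sketch would override only _sample_from_nodes and reuse the same preparation and metadata handling.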

class gigl.distributed.base_sampler.SampleLoopInputs

Inputs prepared for the neighbor sampling loop in _sample_from_nodes.

This dataclass holds the processed inputs that are passed to the core sampling loop. It lets _prepare_sample_loop_inputs customize which nodes are sampled from and which metadata is attached to the output, without duplicating the sampling-loop logic.

nodes_to_sample: torch.Tensor | dict[graphlearn_torch.typing.NodeType, torch.Tensor]

For homogeneous graphs, a tensor of node IDs. For heterogeneous graphs, a dict mapping node types to tensors of node IDs. For ABLP, it also includes the supervision nodes (positive and negative label nodes).

metadata: dict[str, torch.Tensor]

Metadata dict to attach to the sampler output (e.g., label tensors).
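To make the homogeneous and heterogeneous shapes of nodes_to_sample concrete, here is a minimal sketch of how ABLP inputs could be assembled. It is illustrative only and makes these assumptions: plain Python lists stand in for torch.Tensor, and SketchSampleLoopInputs, prepare_ablp_inputs, prepare_ablp_inputs_hetero and the metadata keys "positive_labels" and "negative_labels" are all hypothetical. GiGL's actual _prepare_sample_loop_inputs may build these fields differently.

```python
from dataclasses import dataclass, field


@dataclass
class SketchSampleLoopInputs:
    """Hypothetical mirror of SampleLoopInputs; plain lists replace tensors."""

    nodes_to_sample: object  # list of IDs, or dict: node type -> list of IDs
    metadata: dict = field(default_factory=dict)


def _unique_in_order(ids):
    # Deduplicate while keeping first-seen order.
    return list(dict.fromkeys(ids))


def _label_metadata(anchors, positive_labels, negative_labels):
    # Hypothetical metadata keys: per-anchor supervision node lists.
    return {
        "positive_labels": [positive_labels.get(a, []) for a in anchors],
        "negative_labels": [negative_labels.get(a, []) for a in anchors],
    }


def _supervision_nodes(anchors, positive_labels, negative_labels):
    # Collect every positive and negative label node across all anchors.
    nodes = []
    for a in anchors:
        nodes.extend(positive_labels.get(a, []))
        nodes.extend(negative_labels.get(a, []))
    return nodes


def prepare_ablp_inputs(anchors, positive_labels, negative_labels):
    """Homogeneous case: anchors and supervision nodes share one ID space."""
    supervision = _supervision_nodes(anchors, positive_labels, negative_labels)
    return SketchSampleLoopInputs(
        nodes_to_sample=_unique_in_order(list(anchors) + supervision),
        metadata=_label_metadata(anchors, positive_labels, negative_labels),
    )


def prepare_ablp_inputs_hetero(anchor_type, anchors, label_type,
                               positive_labels, negative_labels):
    """Heterogeneous case: seed IDs are grouped per node type."""
    per_type = {anchor_type: list(anchors)}
    # If label_type == anchor_type, supervision nodes join the anchor list.
    per_type.setdefault(label_type, []).extend(
        _supervision_nodes(anchors, positive_labels, negative_labels)
    )
    return SketchSampleLoopInputs(
        nodes_to_sample={t: _unique_in_order(ids) for t, ids in per_type.items()},
        metadata=_label_metadata(anchors, positive_labels, negative_labels),
    )
```

With anchors [10, 11], positives {10: [20], 11: [21, 20]} and negatives {10: [30], 11: [31]}, the homogeneous call yields nodes_to_sample [10, 11, 20, 30, 21, 31]. The sketch deduplicates seeds so a node that serves as a label for several anchors is sampled from once; whether GiGL deduplicates in the same way is not stated in this reference.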