gigl.distributed.base_dist_loader#

Base distributed loader that consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader.

Subclasses GLT’s DistLoader and handles:
  • Dataset metadata storage
  • Colocated mode: DistLoader attribute setting + staggered producer init
  • Graph Store mode: barrier loop + async RPC dispatch + channel creation

Attributes#

Classes#

BaseDistLoader

Base class for GiGL distributed loaders.

DistributedRuntimeInfo

Plain data container for resolved distributed context information.

Module Contents#

class gigl.distributed.base_dist_loader.BaseDistLoader(dataset, sampler_input, dataset_schema, worker_options, sampling_config, device, runtime, sampler, process_start_gap_seconds=60.0)[source]#

Bases: graphlearn_torch.distributed.DistLoader

Base class for GiGL distributed loaders.

Consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader. Subclasses GLT’s DistLoader but does NOT call its __init__ — instead, it replicates the relevant attribute-setting logic to allow configurable producer classes.

Subclasses should:

  1. Call resolve_runtime() to get runtime context.

  2. Determine mode (colocated vs graph store).

  3. Call create_sampling_config() to build the SamplingConfig.

  4. For colocated: call create_colocated_channel() and construct the DistMpSamplingProducer (or subclass), then pass the producer as sampler.

  5. For graph store: pass the RPC function (e.g. DistServer.create_sampling_producer) as sampler.

  6. Call super().__init__() with the prepared data.
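The colocated-mode call sequence above can be sketched as follows. This is a self-contained illustration only: the Stub* classes and MyLoader are hypothetical stand-ins for BaseDistLoader, DistributedRuntimeInfo, and a concrete loader subclass; only the order of calls mirrors the documented contract.

```python
class StubRuntime:
    """Stand-in for DistributedRuntimeInfo."""
    local_rank = 0

class StubBase:
    """Stand-in for BaseDistLoader; records the call order."""
    calls = []

    @staticmethod
    def resolve_runtime():
        StubBase.calls.append("resolve_runtime")
        return StubRuntime()

    @staticmethod
    def create_sampling_config(num_neighbors, dataset_schema):
        StubBase.calls.append("create_sampling_config")
        return {"fanout": num_neighbors}

    @staticmethod
    def create_colocated_channel(worker_options):
        StubBase.calls.append("create_colocated_channel")
        return object()

    def __init__(self, sampler):
        StubBase.calls.append("super().__init__")
        self.sampler = sampler

class MyLoader(StubBase):
    def __init__(self, dataset, num_neighbors, dataset_schema, worker_options):
        runtime = StubBase.resolve_runtime()                         # step 1
        # step 2: colocated mode is assumed throughout this sketch
        config = StubBase.create_sampling_config(num_neighbors,      # step 3
                                                 dataset_schema)
        channel = StubBase.create_colocated_channel(worker_options)  # step 4
        producer = ("DistMpSamplingProducer", dataset, config, channel)
        super().__init__(sampler=producer)                           # step 6

loader = MyLoader("dataset", [10, 5], "schema", "opts")
```

In graph-store mode, step 4 is skipped and step 5 passes an RPC callable in place of the producer tuple.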

Parameters:
  • dataset (Union[gigl.distributed.dist_dataset.DistDataset, gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset]) – DistDataset (colocated) or RemoteDistDataset (graph store).

  • sampler_input (Union[graphlearn_torch.sampler.NodeSamplerInput, list[graphlearn_torch.sampler.NodeSamplerInput]]) – Prepared by the subclass. Single input for colocated mode, list (one per server) for graph store mode.

  • dataset_schema (gigl.distributed.utils.neighborloader.DatasetSchema) – Contains edge types, feature info, edge dir, etc.

  • worker_options (Union[graphlearn_torch.distributed.MpDistSamplingWorkerOptions, graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions]) – MpDistSamplingWorkerOptions (colocated) or RemoteDistSamplingWorkerOptions (graph store).

  • sampling_config (graphlearn_torch.sampler.SamplingConfig) – Configuration for the sampler (created via create_sampling_config).

  • device (torch.device) – Target device for sampled results.

  • runtime (DistributedRuntimeInfo) – Resolved distributed runtime information.

  • sampler (Union[graphlearn_torch.distributed.dist_sampling_producer.DistMpSamplingProducer, Callable[..., int]]) – Either a pre-constructed DistMpSamplingProducer (colocated mode) or a callable to dispatch on the DistServer (graph store mode).

  • process_start_gap_seconds (float) – Delay between each process for staggered colocated init.
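One plausible reading of process_start_gap_seconds, sketched below: local process k waits k gaps before spawning its sampling producer, so producers on the same host do not all allocate shared memory at the same instant. The function name and the cumulative-delay formula are assumptions for illustration, not the documented implementation.

```python
def producer_start_delay(local_rank: int, gap_seconds: float = 60.0) -> float:
    # Hypothetical staggered startup: process 0 starts immediately,
    # process 1 after one gap, process 2 after two, and so on.
    return local_rank * gap_seconds

# e.g. with 4 local processes and the default 60s gap:
delays = [producer_start_delay(r) for r in range(4)]
```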

static create_colocated_channel(worker_options)[source]#

Creates a ShmChannel for colocated mode.

Creates the shared-memory channel and optionally pins its memory.

Parameters:

worker_options (graphlearn_torch.distributed.MpDistSamplingWorkerOptions) – The colocated worker options (must already be fully configured).

Returns:

A ShmChannel ready to be passed to a DistMpSamplingProducer.

Return type:

graphlearn_torch.channel.ShmChannel

static create_sampling_config(num_neighbors, dataset_schema, batch_size=1, shuffle=False, drop_last=False)[source]#

Creates a SamplingConfig with patched fanout.

Patches num_neighbors to zero out label edge types, then creates the SamplingConfig used by both colocated and graph store modes.

Parameters:
  • num_neighbors (Union[list[int], dict[torch_geometric.typing.EdgeType, list[int]]]) – Fanout per hop.

  • dataset_schema (gigl.distributed.utils.neighborloader.DatasetSchema) – Contains edge types and edge dir.

  • batch_size (int) – How many samples per batch.

  • shuffle (bool) – Whether to shuffle input nodes.

  • drop_last (bool) – Whether to drop the last incomplete batch.

Returns:

A fully configured SamplingConfig.

Return type:

graphlearn_torch.sampler.SamplingConfig
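The "patched fanout" idea can be sketched in plain Python. Everything here is a hypothetical illustration (patch_fanout and the example edge types are not part of the API): label edge types get a zero fanout at every hop so they are never expanded during neighbor sampling, while a plain list fanout is broadcast to every edge type.

```python
def patch_fanout(num_neighbors, edge_types, label_edge_types):
    """Return a per-edge-type fanout dict with label edge types zeroed."""
    if isinstance(num_neighbors, list):
        # Broadcast a flat fanout list to every edge type.
        num_neighbors = {etype: list(num_neighbors) for etype in edge_types}
    return {
        etype: [0] * len(hops) if etype in label_edge_types else hops
        for etype, hops in num_neighbors.items()
    }

fanout = patch_fanout(
    [10, 5],
    edge_types=[("user", "follows", "user"), ("user", "labeled", "user")],
    label_edge_types={("user", "labeled", "user")},
)
```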

static resolve_runtime(context=None, local_process_rank=None, local_process_world_size=None)[source]#

Resolves distributed context from either a DistributedContext or torch.distributed.

Parameters:
  • context (Optional[gigl.distributed.dist_context.DistributedContext]) – (Deprecated) If provided, derives rank info from the DistributedContext. Requires local_process_rank and local_process_world_size.

  • local_process_rank (Optional[int]) – (Deprecated) Required when context is provided.

  • local_process_world_size (Optional[int]) – (Deprecated) Required when context is provided.

Returns:

A DistributedRuntimeInfo containing all resolved rank/topology information.

Return type:

DistributedRuntimeInfo
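When no DistributedContext is given, rank information of this kind is typically derived from the standard torchrun environment variables. The sketch below is an assumption-laden stand-in, not GiGL's implementation: RuntimeInfoSketch mirrors the DistributedRuntimeInfo fields listed below, and the node topology is computed by integer division, which assumes a uniform number of processes per node.

```python
import os
from dataclasses import dataclass

@dataclass
class RuntimeInfoSketch:
    """Hypothetical stand-in for DistributedRuntimeInfo."""
    rank: int
    world_size: int
    local_rank: int
    local_world_size: int
    node_rank: int
    node_world_size: int
    master_ip_address: str

def resolve_from_env() -> RuntimeInfoSketch:
    # RANK / WORLD_SIZE / LOCAL_RANK / LOCAL_WORLD_SIZE are set by torchrun.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    return RuntimeInfoSketch(
        rank=rank,
        world_size=world_size,
        local_rank=int(os.environ["LOCAL_RANK"]),
        local_world_size=local_world_size,
        node_rank=rank // local_world_size,          # assumes uniform nodes
        node_world_size=world_size // local_world_size,
        master_ip_address=os.environ.get("MASTER_ADDR", "127.0.0.1"),
    )

# e.g. global process 5 of 8, with 4 processes per node:
os.environ.update(RANK="5", WORLD_SIZE="8",
                  LOCAL_RANK="1", LOCAL_WORLD_SIZE="4")
info = resolve_from_env()
```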

shutdown()[source]#
Return type:

None

batch_size[source]#
collect_features[source]#
drop_last[source]#
edge_dir[source]#
input_data[source]#
num_neighbors[source]#
sampling_config[source]#
sampling_type[source]#
shuffle[source]#
to_device[source]#
with_edge[source]#
with_weight[source]#
worker_options[source]#
class gigl.distributed.base_dist_loader.DistributedRuntimeInfo[source]#

Plain data container for resolved distributed context information.

local_rank: int[source]#
local_world_size: int[source]#
master_ip_address: str[source]#
node_rank: int[source]#
node_world_size: int[source]#
rank: int[source]#
should_cleanup_distributed_context: bool[source]#
world_size: int[source]#
gigl.distributed.base_dist_loader.logger[source]#