gigl.distributed.base_dist_loader#

Base distributed loader that consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader.

Subclasses GLT’s DistLoader and handles:

  • Dataset metadata storage

  • Colocated mode: DistLoader attribute setting + staggered producer init

  • Graph Store mode: barrier loop + async RPC dispatch + channel creation

Attributes#

Classes#

BaseDistLoader

Base class for GiGL distributed loaders.

DistributedRuntimeInfo

Plain data container for resolved distributed context information.

Module Contents#

class gigl.distributed.base_dist_loader.BaseDistLoader(dataset, sampler_input, dataset_schema, worker_options, sampling_config, device, runtime, producer, sampler_options, process_start_gap_seconds=60.0, max_concurrent_producer_inits=None, non_blocking_transfers=True)[source]#

Bases: graphlearn_torch.distributed.DistLoader

Base class for GiGL distributed loaders.

Consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader. Subclasses GLT’s DistLoader but does NOT call its __init__ — instead, it replicates the relevant attribute-setting logic to allow configurable producer classes.

Subclasses should:

  1. Call resolve_runtime() to get runtime context.

  2. Determine the mode (colocated vs. graph store).

  3. Call create_sampling_config() to build the SamplingConfig.

  4. For colocated mode: call create_colocated_channel(), construct the DistSamplingProducer (or a subclass), and pass it as producer.

  5. For graph store mode: pass the RPC function (e.g. DistServer.create_sampling_producer) as producer.

  6. Call super().__init__() with the prepared data.
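The ordering of these steps can be sketched in plain Python. All class names below are stubs standing in for the real GiGL/GLT classes (which are assumed unavailable here); only the sequencing mirrors the list above.

```python
# Minimal skeleton of a BaseDistLoader subclass. StubBaseDistLoader and
# SketchLoader are hypothetical stand-ins; the real base class takes many
# more arguments (see the Parameters section below).
class StubBaseDistLoader:  # stands in for BaseDistLoader
    def __init__(self, *, dataset, sampling_config, runtime, producer):
        self.dataset = dataset
        self.sampling_config = sampling_config
        self.runtime = runtime
        self.producer = producer


class SketchLoader(StubBaseDistLoader):
    def __init__(self, dataset, num_neighbors, colocated):
        runtime = {"rank": 0, "world_size": 1}              # 1. resolve_runtime()
        mode = "colocated" if colocated else "graph_store"  # 2. determine mode
        sampling_config = {"num_neighbors": num_neighbors}  # 3. create_sampling_config()
        if mode == "colocated":
            channel = object()                 # 4. create_colocated_channel() +
            producer = ("producer", channel)   #    construct DistSamplingProducer
        else:
            producer = lambda *args: 0         # 5. RPC function for the servers
        super().__init__(                      # 6. hand prepared data to the base
            dataset=dataset,
            sampling_config=sampling_config,
            runtime=runtime,
            producer=producer,
        )


loader = SketchLoader(dataset=None, num_neighbors=[10, 5], colocated=True)
print(loader.sampling_config)  # {'num_neighbors': [10, 5]}
```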

Parameters:
  • dataset (Union[gigl.distributed.dist_dataset.DistDataset, gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset]) – DistDataset (colocated) or RemoteDistDataset (graph store).

  • sampler_input (Union[graphlearn_torch.sampler.NodeSamplerInput, list[graphlearn_torch.sampler.NodeSamplerInput]]) – Prepared by the subclass. Single input for colocated mode, list (one per server) for graph store mode.

  • dataset_schema (gigl.distributed.utils.neighborloader.DatasetSchema) – Contains edge types, feature info, edge dir, etc.

  • worker_options (Union[graphlearn_torch.distributed.MpDistSamplingWorkerOptions, graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions]) – MpDistSamplingWorkerOptions (colocated) or RemoteDistSamplingWorkerOptions (graph store).

  • sampling_config (graphlearn_torch.sampler.SamplingConfig) – Configuration for sampling (created via create_sampling_config).

  • device (torch.device) – Target device for sampled results.

  • runtime (DistributedRuntimeInfo) – Resolved distributed runtime information.

  • producer (Union[gigl.distributed.dist_sampling_producer.DistSamplingProducer, Callable[Ellipsis, int]]) – Either a pre-constructed DistSamplingProducer (colocated mode) or a callable to dispatch on the DistServer (graph store mode).

  • sampler_options (gigl.distributed.sampler_options.SamplerOptions) – Controls which sampler class is instantiated.

  • process_start_gap_seconds (float) – Delay between each process for staggered colocated init. In graph store mode, this is the delay between each batch of concurrent producer initializations.

  • max_concurrent_producer_inits (Optional[int]) – Maximum number of leader ranks that may dispatch create_producer_fn RPCs concurrently in graph store mode. Leaders are grouped into batches of this size; each batch sleeps batch_index * process_start_gap_seconds before dispatching. Only applies to graph store mode. Defaults to None (no staggering).

  • non_blocking_transfers (bool)
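The interaction between process_start_gap_seconds and max_concurrent_producer_inits can be illustrated with a small sketch. The batching rule below is taken from the parameter descriptions above; the helper name producer_init_delay is hypothetical, not part of the API.

```python
from typing import Optional


# Hypothetical helper illustrating the graph-store staggering rule: leaders
# are grouped into batches of max_concurrent_producer_inits, and each batch
# sleeps batch_index * process_start_gap_seconds before dispatching its
# create-producer RPCs.
def producer_init_delay(
    leader_index: int,
    max_concurrent_producer_inits: Optional[int],
    process_start_gap_seconds: float = 60.0,
) -> float:
    """Return how long a given leader should wait before dispatching."""
    if max_concurrent_producer_inits is None:
        return 0.0  # no staggering: every leader dispatches immediately
    batch_index = leader_index // max_concurrent_producer_inits
    return batch_index * process_start_gap_seconds


# With batches of 2 and a 60s gap: leaders 0-1 start at t=0,
# leaders 2-3 at t=60, leader 4 at t=120.
delays = [producer_init_delay(i, 2, 60.0) for i in range(5)]
print(delays)  # [0.0, 0.0, 60.0, 60.0, 120.0]
```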

static create_colocated_channel(worker_options)[source]#

Creates a ShmChannel for colocated mode.

Creates the shared-memory channel and optionally pins its memory.

Parameters:

worker_options (graphlearn_torch.distributed.MpDistSamplingWorkerOptions) – The colocated worker options (must already be fully configured).

Returns:

A ShmChannel ready to be passed to a DistSamplingProducer.

Return type:

graphlearn_torch.channel.ShmChannel

static create_colocated_worker_options(*, dataset_num_partitions, num_workers, worker_concurrency, master_ip_address, master_port, channel_size, pin_memory)[source]#

Create worker options for colocated sampling workers.

Parameters:
  • dataset_num_partitions (int) – Number of graph partitions in the colocated dataset.

  • num_workers (int) – Number of sampling worker processes.

  • worker_concurrency (int) – Max sampling concurrency per worker.

  • master_ip_address (str) – Master node IP address used by GLT RPC.

  • master_port (int) – Port for the GLT sampling worker group.

  • channel_size (str) – Shared-memory channel size.

  • pin_memory (bool) – Whether the output channel should be pinned.

Returns:

Fully configured worker options for colocated sampling.

Return type:

graphlearn_torch.distributed.MpDistSamplingWorkerOptions

static create_graph_store_worker_options(*, dataset, compute_rank, worker_key, num_workers, worker_concurrency, channel_size, prefetch_size)[source]#

Create worker options for graph-store sampling workers.

Parameters:
  • dataset (gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset) – Remote dataset proxy used to discover storage-cluster topology.

  • compute_rank (int) – Global compute-process rank for the current process.

  • worker_key (str) – Unique key used by the storage cluster to deduplicate producers.

  • num_workers (int) – Number of sampling worker processes.

  • worker_concurrency (int) – Max sampling concurrency per worker.

  • channel_size (str) – Remote shared-memory buffer size.

  • prefetch_size (int) – Max prefetched messages per storage server.

Returns:

Fully configured worker options for graph-store sampling.

Return type:

graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions

static create_sampling_config(num_neighbors, dataset_schema, batch_size=1, shuffle=False, drop_last=False)[source]#

Creates a SamplingConfig with patched fanout.

Patches num_neighbors to zero out label edge types, then creates the SamplingConfig used by both colocated and graph store modes.

Parameters:
  • num_neighbors (Union[list[int], dict[torch_geometric.typing.EdgeType, list[int]]]) – Fanout per hop.

  • dataset_schema (gigl.distributed.utils.neighborloader.DatasetSchema) – Contains edge types and edge dir.

  • batch_size (int) – How many samples per batch.

  • shuffle (bool) – Whether to shuffle input nodes.

  • drop_last (bool) – Whether to drop the last incomplete batch.

Returns:

A fully configured SamplingConfig.

Return type:

graphlearn_torch.sampler.SamplingConfig
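The "patched fanout" behavior can be sketched for the dict form of num_neighbors. The zero-out rule is inferred from the description above; zero_out_label_edge_types is a hypothetical name, not the real implementation.

```python
# Illustrative sketch of the fanout patch: label (supervision) edge types
# get an all-zero fanout so they are never expanded during neighbor
# sampling, while message-passing edge types keep their fanouts.
def zero_out_label_edge_types(num_neighbors, label_edge_types):
    patched = {}
    for edge_type, fanout in num_neighbors.items():
        if edge_type in label_edge_types:
            patched[edge_type] = [0] * len(fanout)  # never expand label edges
        else:
            patched[edge_type] = list(fanout)
    return patched


fanout = {
    ("user", "follows", "user"): [10, 5],
    ("user", "label", "item"): [10, 5],  # supervision edges
}
patched = zero_out_label_edge_types(fanout, {("user", "label", "item")})
print(patched[("user", "label", "item")])  # [0, 0]
```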

static initialize_colocated_sampling_worker(*, local_rank, local_world_size, node_rank, node_world_size, master_ip_address, device, num_cpu_threads)[source]#

Initialize the colocated GLT worker group for the current process.

Parameters:
  • local_rank (int) – Local rank of the current process on this machine.

  • local_world_size (int) – Total number of local processes on this machine.

  • node_rank (int) – Rank of the current machine.

  • node_world_size (int) – Total number of machines in the cluster.

  • master_ip_address (str) – Master node IP address used for worker-group setup.

  • device (torch.device) – Device assigned to this loader process.

  • num_cpu_threads (Optional[int]) – Optional PyTorch CPU thread count override.

Return type:

None

static resolve_runtime(context=None, local_process_rank=None, local_process_world_size=None)[source]#

Resolves distributed context from either a DistributedContext or torch.distributed.

Parameters:
  • context (Optional[gigl.distributed.dist_context.DistributedContext]) – (Deprecated) If provided, derives rank info from the DistributedContext. Requires local_process_rank and local_process_world_size.

  • local_process_rank (Optional[int]) – (Deprecated) Required when context is provided.

  • local_process_world_size (Optional[int]) – (Deprecated) Required when context is provided.

Returns:

A DistributedRuntimeInfo containing all resolved rank/topology information.

Return type:

DistributedRuntimeInfo
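When no deprecated DistributedContext is given, the runtime is derived from torch.distributed. A self-contained sketch of the kind of derivation involved, using torchrun-style environment variables, is shown below; the real resolve_runtime() reads the initialized process group, so the env-var approach and the helper name here are illustrative assumptions.

```python
# Hypothetical sketch of resolving rank/topology info from torchrun-style
# environment variables (RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE).
def resolve_runtime_from_env(env):
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env["LOCAL_RANK"])
    local_world_size = int(env["LOCAL_WORLD_SIZE"])
    return {
        "rank": rank,
        "world_size": world_size,
        "local_rank": local_rank,
        "local_world_size": local_world_size,
        # Each machine hosts local_world_size processes, so:
        "node_rank": rank // local_world_size,
        "node_world_size": world_size // local_world_size,
        "master_ip_address": env.get("MASTER_ADDR", "localhost"),
    }


info = resolve_runtime_from_env({
    "RANK": "5", "WORLD_SIZE": "8",
    "LOCAL_RANK": "1", "LOCAL_WORLD_SIZE": "4",
    "MASTER_ADDR": "10.0.0.1",
})
print(info["node_rank"], info["node_world_size"])  # 1 2
```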

shutdown()[source]#
Return type:

None

batch_size[source]#
collect_features[source]#
drop_last[source]#
edge_dir[source]#
input_data[source]#
num_neighbors[source]#
sampling_config[source]#
sampling_type[source]#
shuffle[source]#
to_device[source]#
with_edge[source]#
with_weight[source]#
worker_options[source]#
class gigl.distributed.base_dist_loader.DistributedRuntimeInfo[source]#

Plain data container for resolved distributed context information.

local_rank: int[source]#
local_world_size: int[source]#
master_ip_address: str[source]#
node_rank: int[source]#
node_world_size: int[source]#
rank: int[source]#
should_cleanup_distributed_context: bool[source]#
world_size: int[source]#
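The fields above can be mirrored as a small dataclass for illustration. RuntimeInfoSketch is a hypothetical stand-in; the real DistributedRuntimeInfo lives in gigl.distributed.base_dist_loader and may differ in field ordering and defaults.

```python
from dataclasses import dataclass


# Illustrative mirror of the DistributedRuntimeInfo fields listed above.
@dataclass(frozen=True)
class RuntimeInfoSketch:
    rank: int
    world_size: int
    local_rank: int
    local_world_size: int
    node_rank: int
    node_world_size: int
    master_ip_address: str
    should_cleanup_distributed_context: bool


info = RuntimeInfoSketch(
    rank=0, world_size=8, local_rank=0, local_world_size=4,
    node_rank=0, node_world_size=2, master_ip_address="10.0.0.1",
    should_cleanup_distributed_context=False,
)
print(info.node_world_size)  # 2
```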
gigl.distributed.base_dist_loader.DEFAULT_NUM_CPU_THREADS = 2[source]#
gigl.distributed.base_dist_loader.logger[source]#