gigl.distributed.base_dist_loader#
Base distributed loader that consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader.
Subclasses GLT’s DistLoader and handles:

- Dataset metadata storage
- Colocated mode: DistLoader attribute setting + staggered producer init
- Graph Store mode: barrier loop + async RPC dispatch + channel creation
Attributes#
Classes#
| `BaseDistLoader` | Base class for GiGL distributed loaders. |
| `DistributedRuntimeInfo` | Plain data container for resolved distributed context information. |
Module Contents#
- class gigl.distributed.base_dist_loader.BaseDistLoader(dataset, sampler_input, dataset_schema, worker_options, sampling_config, device, runtime, producer, sampler_options, process_start_gap_seconds=60.0, max_concurrent_producer_inits=None, non_blocking_transfers=True)[source]#
Bases: graphlearn_torch.distributed.DistLoader

Base class for GiGL distributed loaders.

Consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader. Subclasses GLT’s DistLoader but does NOT call its `__init__`; instead, it replicates the relevant attribute-setting logic to allow configurable producer classes.

Subclasses should:

1. Call `resolve_runtime()` to get runtime context.
2. Determine mode (colocated vs graph store).
3. Call `create_sampling_config()` to build the SamplingConfig.
4. For colocated: call `create_colocated_channel()` and construct the `DistSamplingProducer` (or subclass), then pass the producer as `producer`. For graph store: pass the RPC function (e.g. `DistServer.create_sampling_producer`) as `producer`.
5. Call `super().__init__()` with the prepared data.
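The five-step flow above can be sketched as a plain-Python skeleton. Everything here is illustrative: the GiGL/GLT classes are replaced by minimal stand-ins (`FakeBaseLoader`, `MyLoader`) so only the control flow is shown; real signatures and return types differ.

```python
# Illustrative stand-ins only; the real BaseDistLoader/producer APIs differ.
class FakeBaseLoader:
    @staticmethod
    def resolve_runtime():
        return {"global_rank": 0, "world_size": 1}

    @staticmethod
    def create_sampling_config(num_neighbors):
        return {"num_neighbors": num_neighbors}

    @staticmethod
    def create_colocated_channel():
        return "shm-channel"

    def __init__(self, *, runtime, sampling_config, producer):
        self.runtime = runtime
        self.sampling_config = sampling_config
        self.producer = producer


class MyLoader(FakeBaseLoader):
    def __init__(self, num_neighbors, colocated=True):
        runtime = self.resolve_runtime()                     # step 1
        # step 2: decide mode (colocated vs graph store)
        config = self.create_sampling_config(num_neighbors)  # step 3
        if colocated:                                        # step 4 (colocated)
            channel = self.create_colocated_channel()
            producer = ("producer-with-channel", channel)
        else:                                                # step 4 (graph store)
            producer = "rpc:create_sampling_producer"
        super().__init__(runtime=runtime,                    # step 5
                         sampling_config=config,
                         producer=producer)
```

The point of the skeleton is the ordering: runtime resolution first, then the mode decision, and a producer that is either a concrete object (colocated) or a dispatchable reference (graph store) by the time `super().__init__()` runs.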
- Parameters:
dataset (Union[gigl.distributed.dist_dataset.DistDataset, gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset]) – `DistDataset` (colocated) or `RemoteDistDataset` (graph store).
sampler_input (Union[graphlearn_torch.sampler.NodeSamplerInput, list[graphlearn_torch.sampler.NodeSamplerInput]]) – Prepared by the subclass. Single input for colocated mode, list (one per server) for graph store mode.
dataset_schema (gigl.distributed.utils.neighborloader.DatasetSchema) – Contains edge types, feature info, edge dir, etc.
worker_options (Union[graphlearn_torch.distributed.MpDistSamplingWorkerOptions, graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions]) – `MpDistSamplingWorkerOptions` (colocated) or `RemoteDistSamplingWorkerOptions` (graph store).
sampling_config (graphlearn_torch.sampler.SamplingConfig) – Configuration for sampling (created via `create_sampling_config`).
device (torch.device) – Target device for sampled results.
runtime (DistributedRuntimeInfo) – Resolved distributed runtime information.
producer (Union[gigl.distributed.dist_sampling_producer.DistSamplingProducer, Callable[..., int]]) – Either a pre-constructed `DistSamplingProducer` (colocated mode) or a callable to dispatch on the `DistServer` (graph store mode).
sampler_options (gigl.distributed.sampler_options.SamplerOptions) – Controls which sampler class is instantiated.
process_start_gap_seconds (float) – Delay between each process for staggered colocated init. In graph store mode, this is the delay between each batch of concurrent producer initializations.
max_concurrent_producer_inits (Optional[int]) – Maximum number of leader ranks that may dispatch `create_producer_fn` RPCs concurrently in graph store mode. Leaders are grouped into batches of this size; each batch sleeps `batch_index * process_start_gap_seconds` before dispatching. Only applies to graph store mode. Defaults to `None` (no staggering).
non_blocking_transfers (bool)
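The batched staggering described for `max_concurrent_producer_inits` reduces to simple arithmetic. A minimal sketch of that rule (the function name is illustrative, not part of the API):

```python
from typing import Optional


def staggered_delay_seconds(
    leader_index: int,
    max_concurrent_producer_inits: Optional[int],
    process_start_gap_seconds: float = 60.0,
) -> float:
    """Seconds a leader rank waits before dispatching its producer-init RPC.

    Leaders are grouped into batches of size max_concurrent_producer_inits;
    batch k sleeps k * process_start_gap_seconds. None disables staggering
    (every leader dispatches immediately).
    """
    if max_concurrent_producer_inits is None:
        return 0.0
    batch_index = leader_index // max_concurrent_producer_inits
    return batch_index * process_start_gap_seconds
```

With `max_concurrent_producer_inits=2` and the default 60-second gap, leaders 0 and 1 dispatch immediately, leaders 2 and 3 after 60 s, leader 4 after 120 s, and so on.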
- static create_colocated_channel(worker_options)[source]#
Creates a ShmChannel for colocated mode.
Creates the shared-memory channel and optionally pins its memory.
- Parameters:
worker_options (graphlearn_torch.distributed.MpDistSamplingWorkerOptions) – The colocated worker options (must already be fully configured).
- Returns:
A ShmChannel ready to be passed to a DistSamplingProducer.
- Return type:
graphlearn_torch.channel.ShmChannel
- static create_colocated_worker_options(*, dataset_num_partitions, num_workers, worker_concurrency, master_ip_address, master_port, channel_size, pin_memory)[source]#
Create worker options for colocated sampling workers.
- Parameters:
dataset_num_partitions (int) – Number of graph partitions in the colocated dataset.
num_workers (int) – Number of sampling worker processes.
worker_concurrency (int) – Max sampling concurrency per worker.
master_ip_address (str) – Master node IP address used by GLT RPC.
master_port (int) – Port for the GLT sampling worker group.
channel_size (str) – Shared-memory channel size.
pin_memory (bool) – Whether the output channel should be pinned.
- Returns:
Fully configured worker options for colocated sampling.
- Return type:
graphlearn_torch.distributed.MpDistSamplingWorkerOptions
- static create_graph_store_worker_options(*, dataset, compute_rank, worker_key, num_workers, worker_concurrency, channel_size, prefetch_size)[source]#
Create worker options for graph-store sampling workers.
- Parameters:
dataset (gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset) – Remote dataset proxy used to discover storage-cluster topology.
compute_rank (int) – Global compute-process rank for the current process.
worker_key (str) – Unique key used by the storage cluster to deduplicate producers.
num_workers (int) – Number of sampling worker processes.
worker_concurrency (int) – Max sampling concurrency per worker.
channel_size (str) – Remote shared-memory buffer size.
prefetch_size (int) – Max prefetched messages per storage server.
- Returns:
Fully configured worker options for graph-store sampling.
- Return type:
graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions
- static create_sampling_config(num_neighbors, dataset_schema, batch_size=1, shuffle=False, drop_last=False)[source]#
Creates a SamplingConfig with patched fanout.
Patches `num_neighbors` to zero out label edge types, then creates the SamplingConfig used by both colocated and graph store modes.
- Parameters:
num_neighbors (Union[list[int], dict[torch_geometric.typing.EdgeType, list[int]]]) – Fanout per hop.
dataset_schema (gigl.distributed.utils.neighborloader.DatasetSchema) – Contains edge types and edge dir.
batch_size (int) – How many samples per batch.
shuffle (bool) – Whether to shuffle input nodes.
drop_last (bool) – Whether to drop the last incomplete batch.
- Returns:
A fully configured SamplingConfig.
- Return type:
graphlearn_torch.sampler.SamplingConfig
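The fanout-patching step can be illustrated with a small stand-alone helper. This sketches only the zero-out behavior, assuming a dict fanout and an explicitly supplied set of label edge types; the real method infers label edge types from the `DatasetSchema` and also accepts a plain list fanout.

```python
from typing import Dict, List, Set, Tuple

# Edge types are (src_node_type, relation, dst_node_type) tuples, as in PyG.
EdgeType = Tuple[str, str, str]


def zero_out_label_fanout(
    num_neighbors: Dict[EdgeType, List[int]],
    label_edge_types: Set[EdgeType],
) -> Dict[EdgeType, List[int]]:
    """Copy the per-edge-type fanout, setting every hop of the label edge
    types to 0 so label edges are never sampled as neighbors."""
    return {
        edge_type: [0] * len(fanout) if edge_type in label_edge_types
        else list(fanout)
        for edge_type, fanout in num_neighbors.items()
    }
```

Zeroing (rather than deleting) the label entries keeps the fanout dict's key set identical for every edge type the sampler knows about, which is why both colocated and graph store modes can share the resulting SamplingConfig.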
- static initialize_colocated_sampling_worker(*, local_rank, local_world_size, node_rank, node_world_size, master_ip_address, device, num_cpu_threads)[source]#
Initialize the colocated GLT worker group for the current process.
- Parameters:
local_rank (int) – Local rank of the current process on this machine.
local_world_size (int) – Total number of local processes on this machine.
node_rank (int) – Rank of the current machine.
node_world_size (int) – Total number of machines in the cluster.
master_ip_address (str) – Master node IP address used for worker-group setup.
device (torch.device) – Device assigned to this loader process.
num_cpu_threads (Optional[int]) – Optional PyTorch CPU thread count override.
- Return type:
None
- static resolve_runtime(context=None, local_process_rank=None, local_process_world_size=None)[source]#
Resolves distributed context from either a DistributedContext or torch.distributed.
- Parameters:
context (Optional[gigl.distributed.dist_context.DistributedContext]) – (Deprecated) If provided, derives rank info from the DistributedContext. Requires local_process_rank and local_process_world_size.
local_process_rank (Optional[int]) – (Deprecated) Required when context is provided.
local_process_world_size (Optional[int]) – (Deprecated) Required when context is provided.
- Returns:
A DistributedRuntimeInfo containing all resolved rank/topology information.
- Return type:
DistributedRuntimeInfo
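When `context` is not given, this kind of rank/topology information is typically derived from the `torch.distributed` launcher environment. A minimal sketch of that derivation, assuming the standard `torchrun` environment variables (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `LOCAL_WORLD_SIZE`); the dataclass and its field names are an illustrative stand-in, not the real `DistributedRuntimeInfo`:

```python
import os
from dataclasses import dataclass


@dataclass
class RuntimeInfoSketch:
    """Illustrative stand-in for DistributedRuntimeInfo; real fields differ."""
    global_rank: int
    global_world_size: int
    local_rank: int
    local_world_size: int
    node_rank: int
    node_world_size: int


def resolve_runtime_from_env() -> RuntimeInfoSketch:
    """Derive rank/topology info from torchrun-style environment variables."""
    global_rank = int(os.environ["RANK"])
    global_world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    return RuntimeInfoSketch(
        global_rank=global_rank,
        global_world_size=global_world_size,
        local_rank=local_rank,
        local_world_size=local_world_size,
        # Assumes each machine holds a contiguous block of global ranks.
        node_rank=global_rank // local_world_size,
        node_world_size=global_world_size // local_world_size,
    )
```

For example, on a 2-machine, 4-process-per-machine job, global rank 5 resolves to node rank 1 with local rank 1.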