gigl.distributed.base_dist_loader#
Base distributed loader that consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader.
Subclasses GLT’s DistLoader and handles:

- Dataset metadata storage
- Colocated mode: DistLoader attribute setting + staggered producer init
- Graph Store mode: barrier loop + async RPC dispatch + channel creation
Attributes#
Classes#
- BaseDistLoader: Base class for GiGL distributed loaders.
- DistributedRuntimeInfo: Plain data container for resolved distributed context information.
Module Contents#
- class gigl.distributed.base_dist_loader.BaseDistLoader(dataset, sampler_input, dataset_schema, worker_options, sampling_config, device, runtime, sampler, process_start_gap_seconds=60.0)[source]#
Bases: graphlearn_torch.distributed.DistLoader

Base class for GiGL distributed loaders.

Consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader. Subclasses GLT’s DistLoader but does NOT call its __init__; instead, it replicates the relevant attribute-setting logic to allow configurable producer classes.

Subclasses should:

1. Call resolve_runtime() to get runtime context.
2. Determine the mode (colocated vs. graph store).
3. Call create_sampling_config() to build the SamplingConfig.
4. For colocated mode: call create_colocated_channel() and construct the DistMpSamplingProducer (or a subclass), then pass the producer as sampler. For graph store mode: pass the RPC function (e.g. DistServer.create_sampling_producer) as sampler.
5. Call super().__init__() with the prepared data.
- Parameters:
dataset (Union[gigl.distributed.dist_dataset.DistDataset, gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset]) – DistDataset (colocated) or RemoteDistDataset (graph store).
sampler_input (Union[graphlearn_torch.sampler.NodeSamplerInput, list[graphlearn_torch.sampler.NodeSamplerInput]]) – Prepared by the subclass. A single input for colocated mode; a list (one per server) for graph store mode.
dataset_schema (gigl.distributed.utils.neighborloader.DatasetSchema) – Contains edge types, feature info, edge direction, etc.
worker_options (Union[graphlearn_torch.distributed.MpDistSamplingWorkerOptions, graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions]) – MpDistSamplingWorkerOptions (colocated) or RemoteDistSamplingWorkerOptions (graph store).
sampling_config (graphlearn_torch.sampler.SamplingConfig) – Configuration for the sampler (created via create_sampling_config()).
device (torch.device) – Target device for sampled results.
runtime (DistributedRuntimeInfo) – Resolved distributed runtime information.
sampler (Union[graphlearn_torch.distributed.dist_sampling_producer.DistMpSamplingProducer, Callable[..., int]]) – Either a pre-constructed DistMpSamplingProducer (colocated mode) or a callable to dispatch on the DistServer (graph store mode).
process_start_gap_seconds (float) – Delay between each process for staggered colocated initialization.
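The two-mode construction flow above can be sketched in plain Python. Every class and function below is an illustrative stand-in, not the real GiGL/GLT API; only the control flow (resolve runtime, build config, then producer-vs-RPC-callable, with the staggered startup delay) mirrors the description.

```python
# Hedged sketch of the subclass flow. StubProducer and build_loader_inputs
# are hypothetical stand-ins for DistMpSamplingProducer and the subclass
# __init__ logic described above.
import time


class StubProducer:
    """Stand-in for a pre-constructed DistMpSamplingProducer."""

    def __init__(self, channel):
        self.channel = channel


def build_loader_inputs(colocated, rank, process_start_gap_seconds=0.0):
    # 1. Resolve runtime context (stand-in for resolve_runtime()).
    runtime = {"rank": rank, "world_size": 2}
    # 2./3. Build the sampling config (stand-in for create_sampling_config()).
    sampling_config = {"batch_size": 1, "shuffle": False}
    if colocated:
        # Stagger startup across local processes, as process_start_gap_seconds does.
        time.sleep(rank * process_start_gap_seconds)
        # 4a. Channel + producer (stand-ins for create_colocated_channel()
        #     and DistMpSamplingProducer).
        sampler = StubProducer(channel=object())
    else:
        # 4b. Graph store mode: pass an RPC callable instead of a producer.
        sampler = lambda *args, **kwargs: 0
    # 5. These would then be handed to super().__init__().
    return runtime, sampling_config, sampler


runtime, cfg, sampler = build_loader_inputs(colocated=True, rank=0)
```

The key design point this mirrors: sampler is either a concrete producer object (colocated) or a callable dispatched remotely (graph store), so the base class can stay agnostic to the mode.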
- static create_colocated_channel(worker_options)[source]#
Creates a ShmChannel for colocated mode.
Creates the shared-memory channel and optionally pins its memory.
- Parameters:
worker_options (graphlearn_torch.distributed.MpDistSamplingWorkerOptions) – The colocated worker options (must already be fully configured).
- Returns:
A ShmChannel ready to be passed to a DistMpSamplingProducer.
- Return type:
graphlearn_torch.channel.ShmChannel
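As a rough illustration of the channel-creation step, the sketch below uses a bounded queue sized from the worker options. SimpleWorkerOptions and its fields are hypothetical stand-ins for MpDistSamplingWorkerOptions; the real method returns a graphlearn_torch ShmChannel and, when pin-memory is requested, pins its buffers for faster host-to-GPU copies.

```python
# Illustrative stand-in for create_colocated_channel(); not the real API.
import queue
from dataclasses import dataclass


@dataclass
class SimpleWorkerOptions:
    # Hypothetical fields mirroring "must already be fully configured".
    channel_capacity: int = 16
    pin_memory: bool = False


def create_channel(opts: SimpleWorkerOptions):
    # Bounded queue stands in for the shared-memory ShmChannel; in the
    # real loader, opts.pin_memory would trigger pinning of the channel's
    # shared-memory buffers.
    return queue.Queue(maxsize=opts.channel_capacity)


ch = create_channel(SimpleWorkerOptions(channel_capacity=4))
```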
- static create_sampling_config(num_neighbors, dataset_schema, batch_size=1, shuffle=False, drop_last=False)[source]#
Creates a SamplingConfig with patched fanout.
Patches num_neighbors to zero out the fanout for label edge types, then creates the SamplingConfig used by both colocated and graph store modes.
- Parameters:
num_neighbors (Union[list[int], dict[torch_geometric.typing.EdgeType, list[int]]]) – Fanout per hop.
dataset_schema (gigl.distributed.utils.neighborloader.DatasetSchema) – Contains edge types and edge dir.
batch_size (int) – How many samples per batch.
shuffle (bool) – Whether to shuffle input nodes.
drop_last (bool) – Whether to drop the last incomplete batch.
- Returns:
A fully configured SamplingConfig.
- Return type:
graphlearn_torch.sampler.SamplingConfig
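The fanout patching described above can be shown with a plain dict. The edge types and the label set here are illustrative, and patch_fanout is a hypothetical helper, not the library's internal function; it only demonstrates the effect of zeroing out label edge types so they are never traversed as neighbors.

```python
# Hedged sketch of the fanout patching done by create_sampling_config().
from typing import Dict, List, Set, Tuple

EdgeType = Tuple[str, str, str]


def patch_fanout(
    num_neighbors: Dict[EdgeType, List[int]],
    label_edge_types: Set[EdgeType],
) -> Dict[EdgeType, List[int]]:
    # Replace the fanout of label edge types with zeros of the same length,
    # leaving all other edge types untouched.
    return {
        etype: ([0] * len(fanout) if etype in label_edge_types else fanout)
        for etype, fanout in num_neighbors.items()
    }


fanout = {
    ("user", "follows", "user"): [10, 5],
    ("user", "labeled", "item"): [10, 5],  # label edge type (illustrative)
}
patched = patch_fanout(fanout, {("user", "labeled", "item")})
# patched[("user", "labeled", "item")] == [0, 0]
```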
- static resolve_runtime(context=None, local_process_rank=None, local_process_world_size=None)[source]#
Resolves distributed context from either a DistributedContext or torch.distributed.
- Parameters:
context (Optional[gigl.distributed.dist_context.DistributedContext]) – (Deprecated) If provided, derives rank info from the DistributedContext. Requires local_process_rank and local_process_world_size.
local_process_rank (Optional[int]) – (Deprecated) Required when context is provided.
local_process_world_size (Optional[int]) – (Deprecated) Required when context is provided.
- Returns:
A DistributedRuntimeInfo containing all resolved rank/topology information.
- Return type:
DistributedRuntimeInfo
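When no (deprecated) DistributedContext is supplied, rank information of this kind is typically derived from the torch.distributed environment. A minimal sketch, assuming the standard RANK/WORLD_SIZE variables exported by launchers such as torchrun; RuntimeInfo and resolve_from_env are hypothetical stand-ins for DistributedRuntimeInfo and resolve_runtime():

```python
# Hedged sketch of environment-based runtime resolution; not the real API.
import os
from dataclasses import dataclass


@dataclass
class RuntimeInfo:
    rank: int
    world_size: int


def resolve_from_env() -> RuntimeInfo:
    # torch.distributed launchers (e.g. torchrun) export RANK and WORLD_SIZE.
    return RuntimeInfo(
        rank=int(os.environ.get("RANK", 0)),
        world_size=int(os.environ.get("WORLD_SIZE", 1)),
    )


os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
info = resolve_from_env()
```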