gigl.distributed.graph_store.shared_dist_sampling_producer#
Shared graph-store sampling backend and fair-queued worker loop.
This module implements the multi-channel sampling backend used in graph-store
mode. A single SharedDistSamplingBackend per loader instance manages a
pool of worker processes that service many compute-rank channels through a
fair-queued scheduler (_shared_sampling_worker_loop).
We need this fair-queued scheduler so that each compute rank gets a fair share of sampling work. Without it, compute ranks with more data would starve compute ranks with less data, since sample_from_* calls for the smaller ranks would block behind work for the larger ranks. Surprisingly, raising worker_concurrency does not fix this problem. TODO(kmonte): Look into why worker_concurrency does not fix this problem.
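The fairness property can be illustrated with a toy model of the scheduler: each pass submits at most one batch per runnable channel and re-enqueues channels that still have work, so a channel with many batches cannot monopolize the workers. This is an illustrative sketch, not the module's actual loop; fair_schedule and its dict input are hypothetical.

```python
from collections import deque

def fair_schedule(channel_batches: dict[int, int]) -> list[int]:
    """Return the batch-submission order for {channel_id: num_batches}."""
    runnable = deque(sorted(channel_batches))
    remaining = dict(channel_batches)
    order = []
    while runnable:
        channel_id = runnable.popleft()
        order.append(channel_id)        # submit one batch for this channel
        remaining[channel_id] -= 1
        if remaining[channel_id] > 0:   # more batches left: re-enqueue at tail
            runnable.append(channel_id)
    return order
```

Even when channel 0 has three times the batches of channel 1, channel 1's single batch is submitted on the first pass rather than waiting behind all of channel 0's work: `fair_schedule({0: 3, 1: 1})` yields `[0, 1, 0, 0]`.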
High-level architecture:
┌──────────────────────────────────────────────┐
│          SharedDistSamplingBackend           │
│               (main process)                 │
├──────────────────────────────────────────────┤
│  register_input()                            │
│  start_new_epoch_sampling()                  │
│  is_channel_epoch_done()                     │
│  unregister_input()                          │
│  shutdown()                                  │
└──────┬──────────────────────────────▲────────┘
       │ task_queues                  │ event_queue
       │ (SharedMpCommand, payload)   │ (EPOCH_DONE_EVENT,
       │                              │  channel_id, epoch,
       ▼                              │  worker_rank)
┌──────────────────────────────────────────────┐
│               Worker 0 .. N-1                │
│        _shared_sampling_worker_loop()        │
│                                              │
│  ┌──────────────┐ sample_from_* ┌─────────┐  │
│  │   Sampler    │──────────────▶│ Channel │  │
│  │ (per channel)│   (results)   │ (output)│  │
│  └──────────────┘               └─────────┘  │
└──────────────────────────────────────────────┘
Worker event-loop internals:
┌─────────────────────────────────────────────────┐
│ Phase 1: Drain commands (non-blocking)          │
│   task_queue.get_nowait() ──▶ _handle_command() │
│     REGISTER_INPUT   ──▶ create sampler + state │
│     START_EPOCH      ──▶ ActiveEpochState       │
│                          + enqueue to runnable  │
│     UNREGISTER_INPUT ──▶ cleanup / defer        │
│     STOP             ──▶ exit loop              │
├─────────────────────────────────────────────────┤
│ Phase 2: Round-robin batch submission           │
│   for each channel in runnable_channel_ids:     │
│     pop ──▶ _submit_one_batch()                 │
│         ──▶ sampler.sample_from_*()             │
│     if more batches: re-enqueue channel         │
│                                                 │
│   completion callback (_on_batch_done):         │
│     completed_batches += 1                      │
│     if all done ──▶ EPOCH_DONE to event_queue   │
├─────────────────────────────────────────────────┤
│ Phase 3: Idle wait                              │
│   if no commands and no batches submitted:      │
│     task_queue.get(timeout=SCHEDULER_TICK_SECS) │
└─────────────────────────────────────────────────┘
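The three phases above can be modeled with a thread-safe queue and plain callables. This is a hypothetical simplification: the real loop lives in _shared_sampling_worker_loop and dispatches real commands and sampler calls, and the SCHEDULER_TICK_SECS value here is assumed.

```python
import queue

SCHEDULER_TICK_SECS = 0.01  # assumed value for this sketch

def worker_loop(task_queue, handle_command, submit_one_batch, runnable):
    while True:
        # Phase 1: drain commands without blocking.
        drained = 0
        try:
            while True:
                cmd = task_queue.get_nowait()
                drained += 1
                if handle_command(cmd) == "stop":
                    return
        except queue.Empty:
            pass
        # Phase 2: one batch per runnable channel, round-robin.
        submitted = 0
        for _ in range(len(runnable)):
            channel_id = runnable.pop(0)
            if submit_one_batch(channel_id):  # True if batches remain
                runnable.append(channel_id)   # re-enqueue at the tail
            submitted += 1
        # Phase 3: idle wait only when nothing happened this pass.
        if drained == 0 and submitted == 0:
            try:
                cmd = task_queue.get(timeout=SCHEDULER_TICK_SECS)
                if handle_command(cmd) == "stop":
                    return
            except queue.Empty:
                pass
```

The idle wait in Phase 3 bounds the latency of reacting to a new command without busy-spinning when all channels are drained.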
Attributes#
SCHEDULER_SLOW_SUBMIT_SECS
SCHEDULER_STATE_LOG_INTERVAL_SECS
Classes#
ActiveEpochState – Mutable per-channel state for an in-progress epoch inside a worker.
RegisterInputCmd – Payload for SharedMpCommand.REGISTER_INPUT.
SharedDistSamplingBackend – Shared graph-store sampling backend reused across many remote channels.
SharedMpCommand – Commands sent from the backend to worker subprocesses via task queues.
StartEpochCmd – Payload for SharedMpCommand.START_EPOCH.
Module Contents#
- class gigl.distributed.graph_store.shared_dist_sampling_producer.ActiveEpochState[source]#
Mutable per-channel state for an in-progress epoch inside a worker.
Created by _handle_command on START_EPOCH and removed when all batches complete.
- seeds_index[source]#
Index tensor into the channel's sampler_input. None means sequential indices [0, input_len).
- submitted_batches[source]#
Number of batches submitted to the sampler so far. Mutated by _submit_one_batch.
- completed_batches[source]#
Number of batches whose sampler callbacks have fired. Mutated by _on_batch_done.
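A plausible shape for this state is a small dataclass over the three documented fields. The class below is a hypothetical sketch (the real ActiveEpochState lives in this module and may carry more fields); epoch_done and the Any-typed seeds_index stand in for tensor-typed internals.

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ActiveEpochStateSketch:
    seeds_index: Optional[Any]   # index tensor, or None for sequential [0, input_len)
    submitted_batches: int = 0   # bumped by the batch-submission path
    completed_batches: int = 0   # bumped by the completion callback

    def epoch_done(self, total_batches: int) -> bool:
        # The worker reports EPOCH_DONE once every batch has completed.
        return self.completed_batches == total_batches
```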
- class gigl.distributed.graph_store.shared_dist_sampling_producer.RegisterInputCmd[source]#
Payload for SharedMpCommand.REGISTER_INPUT. Carries everything a worker needs to set up sampling for one channel.
- worker_key[source]#
Routing key used to identify this channel in the worker group (passed through to create_dist_sampler).
- class gigl.distributed.graph_store.shared_dist_sampling_producer.SharedDistSamplingBackend(*, data, worker_options, sampling_config, sampler_options, degree_tensors)[source]#
Shared graph-store sampling backend reused across many remote channels.
Initialize the shared sampling backend.
Does not start worker processes; call init_backend to spawn them.
- Parameters:
data (graphlearn_torch.distributed.DistDataset) – The distributed dataset to sample from.
worker_options (graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions) – GLT remote sampling worker configuration (RPC addresses, devices, concurrency).
sampling_config (graphlearn_torch.sampler.SamplingConfig) – Sampling parameters (batch size, neighbor counts, shuffle, etc.). All channels registered on this backend must use the same config.
sampler_options (gigl.distributed.sampler_options.SamplerOptions) – GiGL sampler variant configuration (e.g. PPRSamplerOptions for PPR-based sampling).
degree_tensors (Optional[Union[torch.Tensor, dict[graphlearn_torch.typing.EdgeType, torch.Tensor]]]) – Pre-computed degree tensors for PPR sampling (if applicable).
- describe_channel(channel_id)[source]#
Return lightweight diagnostics for one registered channel.
Drains pending completion events before building the snapshot.
- Parameters:
channel_id (int) – The channel to describe.
- Returns:
A dict with keys:
"epoch": Current epoch number (-1 if never started).
"input_sizes": Per-worker seed counts.
"completed_workers": Number of workers that finished the current epoch.
- Return type:
dict
- init_backend()[source]#
Initialize worker processes once for this backend.
Spawns num_workers subprocesses running _shared_sampling_worker_loop. Each worker initializes RPC and signals readiness via a shared barrier. This method blocks until all workers are ready.
The initialization sequence is:
1. Assign devices and worker ranks from the GLT server context.
2. Spawn worker processes with per-worker task queues and a shared event queue.
3. Wait on the barrier for all workers to finish RPC init.
No-op if already initialized.
- Raises:
RuntimeError – If no GLT server context is active.
- Return type:
None
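The "spawn, then block on a readiness barrier" pattern can be sketched as follows. For brevity this hypothetical version uses threads and a threading.Barrier rather than processes and RPC; the structure (workers signal readiness, the main side waits for all of them) is the point.

```python
import threading

def start_workers(num_workers: int, worker_fn) -> list[threading.Thread]:
    barrier = threading.Barrier(num_workers + 1)  # workers + main side
    workers = []
    for rank in range(num_workers):
        def run(rank=rank):
            worker_fn(rank)   # stands in for per-worker setup (e.g. RPC init)
            barrier.wait()    # signal readiness
        t = threading.Thread(target=run, daemon=True)
        t.start()
        workers.append(t)
    barrier.wait()            # block until every worker has reached the barrier
    return workers
```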
- is_channel_epoch_done(channel_id, epoch)[source]#
Return whether every worker finished the epoch for one channel.
Drains pending completion events before checking.
- Parameters:
channel_id (int) – The channel to query.
epoch (int) – The epoch number to check.
- Returns:
True if all num_workers workers have reported EPOCH_DONE for this (channel_id, epoch) pair.
- Raises:
RuntimeError – If any worker process has died.
- Return type:
bool
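The "drain events, then count workers" check can be sketched like this. The event tuple layout and the EPOCH_DONE_EVENT sentinel value here are assumptions modeled on the architecture diagram (kind, channel_id, epoch, worker_rank); the real backend's bookkeeping may differ.

```python
import queue

EPOCH_DONE_EVENT = "EPOCH_DONE"  # hypothetical stand-in for the real sentinel

def drain_and_check(event_queue, done, channel_id, epoch, num_workers):
    """Drain pending events into `done`, then check one (channel, epoch)."""
    try:
        while True:
            kind, cid, ep, rank = event_queue.get_nowait()
            if kind == EPOCH_DONE_EVENT:
                done.setdefault((cid, ep), set()).add(rank)
    except queue.Empty:
        pass
    # Done only once every worker rank has reported for this epoch.
    return len(done.get((channel_id, epoch), ())) == num_workers
```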
- register_input(channel_id, worker_key, sampler_input, sampling_config, channel)[source]#
Register a new channel on all backend workers.
Moves sampler_input into shared memory, computes per-worker seed ranges, initializes shuffle state (if configured), and broadcasts a REGISTER_INPUT command to every worker.
- Parameters:
channel_id (int) – Unique identifier for this channel.
worker_key (str) – Routing key for the channel in the worker group.
sampler_input (gigl.distributed.utils.dist_sampler.SamplerInput) – Seed node/edge inputs for this channel.
sampling_config (graphlearn_torch.sampler.SamplingConfig) – Must match the backend's sampling_config.
channel (graphlearn_torch.channel.ChannelBase) – Output channel where sampled subgraphs are written.
- Raises:
RuntimeError – If the backend has not been initialized via init_backend.
ValueError – If channel_id is already registered, or if sampling_config does not match the backend config.
- Return type:
None
- shutdown()[source]#
Stop all worker processes and release backend resources.
Cleanup sequence:
1. Send STOP to every worker's task queue.
2. Join each worker with a timeout of MP_STATUS_CHECK_INTERVAL seconds.
3. Close all task queues and the event queue.
4. Terminate any workers still alive after the join timeout.
No-op if already shut down.
- Return type:
None
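The cleanup sequence above can be sketched against hypothetical worker handles exposing join(timeout), is_alive(), and terminate(), which is exactly the surface multiprocessing.Process provides. The MP_STATUS_CHECK_INTERVAL value here is assumed for the sketch.

```python
MP_STATUS_CHECK_INTERVAL = 5.0  # assumed timeout for this sketch

def shutdown_workers(task_queues, workers, event_queue):
    for tq in task_queues:
        tq.put(("STOP", None))                    # 1. ask each worker to exit
    for w in workers:
        w.join(timeout=MP_STATUS_CHECK_INTERVAL)  # 2. bounded join
    for tq in task_queues:
        tq.close()                                # 3. release queue resources
    event_queue.close()
    for w in workers:
        if w.is_alive():
            w.terminate()                         # 4. force any stragglers
```

The bounded join plus terminate fallback guarantees shutdown completes even if a worker is wedged mid-sample.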
- start_new_epoch_sampling(channel_id, epoch)[source]#
Start a new sampling epoch for one registered channel.
Cleans up stale completion records, generates a shuffled or sequential seed permutation, slices it into per-worker ranges, and dispatches START_EPOCH commands to all workers.
No-op if the channel has already started an epoch >= epoch.
Callers must ensure the previous epoch's sampling has completed before starting a new one. BaseDistLoader.__iter__ guarantees this by consuming all batches via __next__ before re-entering the loop.
channel_id (int) – The registered channel to start.
epoch (int) – Monotonically increasing epoch number.
- Raises:
KeyError – If channel_id is not registered.
- Return type:
None
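The per-epoch seed split can be sketched as: build an (optionally shuffled) permutation of [0, input_len) and slice it into near-equal per-worker ranges. This hypothetical version uses plain lists to stay dependency-free; the real module works on tensors.

```python
import random

def split_epoch_seeds(input_len, num_workers, shuffle=False, seed=None):
    """Return one list of seed indices per worker rank."""
    perm = list(range(input_len))
    if shuffle:
        random.Random(seed).shuffle(perm)
    # Distribute the remainder over the first `extra` ranks.
    base, extra = divmod(input_len, num_workers)
    slices, start = [], 0
    for rank in range(num_workers):
        size = base + (1 if rank < extra else 0)
        slices.append(perm[start:start + size])
        start += size
    return slices
```

Every seed lands in exactly one worker's range, shuffled or not, so an epoch covers the full input exactly once.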
- class gigl.distributed.graph_store.shared_dist_sampling_producer.SharedMpCommand(*args, **kwds)[source]#
Bases: enum.Enum
Commands sent from the backend to worker subprocesses via task queues.
Each command is paired with a payload in a (command, payload) tuple placed on the per-worker task_queue.
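A hypothetical rendering of this enum, using the four commands named in the worker event-loop diagram (REGISTER_INPUT, START_EPOCH, UNREGISTER_INPUT, STOP); the actual member values may differ.

```python
import enum

class SharedMpCommandSketch(enum.Enum):
    REGISTER_INPUT = enum.auto()
    START_EPOCH = enum.auto()
    UNREGISTER_INPUT = enum.auto()
    STOP = enum.auto()

# A command message as it would be placed on a per-worker task queue:
msg = (SharedMpCommandSketch.START_EPOCH, {"channel_id": 0, "epoch": 1})
```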
- class gigl.distributed.graph_store.shared_dist_sampling_producer.StartEpochCmd[source]#
Payload for SharedMpCommand.START_EPOCH.
- epoch[source]#
Monotonically increasing epoch number. Duplicate or stale epochs are silently ignored by the worker.
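The stale-epoch guard can be sketched as follows: a worker tracks the highest epoch started per channel and silently drops anything not strictly newer. The function and dict layout are hypothetical stand-ins for the worker's internal bookkeeping.

```python
def should_start_epoch(started: dict[int, int], channel_id: int, epoch: int) -> bool:
    """Return True and record the epoch iff it is strictly newer for the channel."""
    if epoch <= started.get(channel_id, -1):
        return False  # duplicate or stale: silently ignored
    started[channel_id] = epoch
    return True
```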
- gigl.distributed.graph_store.shared_dist_sampling_producer.SCHEDULER_SLOW_SUBMIT_SECS = 1.0[source]#
- gigl.distributed.graph_store.shared_dist_sampling_producer.SCHEDULER_STATE_LOG_INTERVAL_SECS = 10.0[source]#