gigl.distributed.graph_store.shared_dist_sampling_producer#

Shared graph-store sampling backend and fair-queued worker loop.

This module implements the multi-channel sampling backend used in graph-store mode. A single SharedDistSamplingBackend per loader instance manages a pool of worker processes that service many compute-rank channels through a fair-queued scheduler (_shared_sampling_worker_loop).

The “fair-queued” scheduler ensures that each compute rank gets a fair share of sampling work. Without it, compute ranks with less data would be starved: their sample_from_* calls would queue behind the larger workloads of the compute ranks with more data. Surprisingly, increasing worker_concurrency does not fix this. TODO(kmonte): Investigate why worker_concurrency does not fix this problem.
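As a minimal sketch of the fairness property (stdlib only, no GiGL imports; `fair_round_robin` is an illustrative helper, not part of this module), the scheduler can be modeled as a FIFO of runnable channels where each channel submits exactly one batch per pass:

```python
from collections import deque

def fair_round_robin(batches_per_channel: dict[str, int]) -> list[str]:
    # Toy model of the fair-queued scheduler: each runnable channel
    # submits exactly one batch per pass, then goes to the back of the
    # queue, so a large channel cannot starve a small one.
    runnable = deque(batches_per_channel)
    remaining = dict(batches_per_channel)
    order: list[str] = []
    while runnable:
        channel_id = runnable.popleft()
        order.append(channel_id)  # "submit" one batch for this channel
        remaining[channel_id] -= 1
        if remaining[channel_id] > 0:
            runnable.append(channel_id)  # more batches left: re-enqueue
    return order
```

With `{"a": 3, "b": 1}` the submission order interleaves (`a, b, a, a`) rather than running all of `a` before `b` ever gets a turn.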

High-level architecture:

┌──────────────────────────────────────────────┐
│      SharedDistSamplingBackend               │
│      (main process)                          │
├──────────────────────────────────────────────┤
│  register_input()                            │
│  start_new_epoch_sampling()                  │
│  is_channel_epoch_done()                     │
│  unregister_input()                          │
│  shutdown()                                  │
└──────┬──────────────────────────────▲────────┘
       │ task_queues                  │ event_queue
       │ (SharedMpCommand, payload)   │ (EPOCH_DONE_EVENT,
       │                              │  channel_id, epoch,
       ▼                              │  worker_rank)
┌──────────────────────────────────────────────┐
│  Worker 0 .. N-1                             │
│  _shared_sampling_worker_loop()              │
│                                              │
│  ┌──────────────┐ sample_from_*  ┌─────────┐ │
│  │ Sampler      │───────────────▶│ Channel │ │
│  │ (per channel)│ (results)      │ (output)│ │
│  └──────────────┘                └─────────┘ │
└──────────────────────────────────────────────┘

Worker event-loop internals:

┌─────────────────────────────────────────────────┐
│ Phase 1: Drain commands (non-blocking)          │
│   task_queue.get_nowait() ──▶ _handle_command() │
│     REGISTER_INPUT  ──▶ create sampler + state  │
│     START_EPOCH     ──▶ ActiveEpochState        │
│                          + enqueue to runnable  │
│     UNREGISTER_INPUT ──▶ cleanup / defer        │
│     STOP             ──▶ exit loop              │
├─────────────────────────────────────────────────┤
│ Phase 2: Round-robin batch submission           │
│   for each channel in runnable_channel_ids:     │
│     pop ──▶ _submit_one_batch()                 │
│            ──▶ sampler.sample_from_*()          │
│     if more batches: re-enqueue channel         │
│                                                 │
│   completion callback (_on_batch_done):         │
│     completed_batches += 1                      │
│     if all done ──▶ EPOCH_DONE to event_queue   │
├─────────────────────────────────────────────────┤
│ Phase 3: Idle wait                              │
│   if no commands and no batches submitted:      │
│     task_queue.get(timeout=SCHEDULER_TICK_SECS) │
└─────────────────────────────────────────────────┘
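The three phases above can be sketched with a plain queue.Queue standing in for the per-worker task queue (a simplified model: `handle_command` and `submit_ready_batches` are hypothetical callbacks, and string commands stand in for SharedMpCommand values):

```python
import queue

SCHEDULER_TICK_SECS = 0.05

def worker_loop(task_queue, handle_command, submit_ready_batches):
    # Phase 1: drain commands without blocking; Phase 2: one round-robin
    # submission pass; Phase 3: block briefly only when fully idle, so
    # an empty loop does not spin the CPU.
    while True:
        drained = 0
        while True:  # Phase 1: drain commands (non-blocking)
            try:
                cmd = task_queue.get_nowait()
            except queue.Empty:
                break
            if cmd == "STOP":
                return
            handle_command(cmd)
            drained += 1
        submitted = submit_ready_batches()  # Phase 2: batch submission
        if drained == 0 and submitted == 0:  # Phase 3: idle wait
            try:
                cmd = task_queue.get(timeout=SCHEDULER_TICK_SECS)
            except queue.Empty:
                continue
            if cmd == "STOP":
                return
            handle_command(cmd)
```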

Attributes#

Classes#

ActiveEpochState

Mutable per-channel state for an in-progress epoch inside a worker.

RegisterInputCmd

Payload for SharedMpCommand.REGISTER_INPUT.

SharedDistSamplingBackend

Shared graph-store sampling backend reused across many remote channels.

SharedMpCommand

Commands sent from the backend to worker subprocesses via task queues.

StartEpochCmd

Payload for SharedMpCommand.START_EPOCH.

Module Contents#

class gigl.distributed.graph_store.shared_dist_sampling_producer.ActiveEpochState[source]#

Mutable per-channel state for an in-progress epoch inside a worker.

Created by _handle_command on START_EPOCH and removed when all batches complete.

channel_id[source]#

The channel this epoch belongs to.

epoch[source]#

The epoch number.

input_len[source]#

Total number of seed indices assigned to this worker for this epoch.

batch_size[source]#

Number of seeds per batch.

drop_last[source]#

If True, the final incomplete batch is skipped.

seeds_index[source]#

Index tensor into the channel’s sampler_input. None means sequential indices [0, input_len).

total_batches[source]#

Pre-computed number of batches for this epoch.

submitted_batches[source]#

Number of batches submitted to the sampler so far. Mutated by _submit_one_batch.

completed_batches[source]#

Number of batches whose sampler callbacks have fired. Mutated by _on_batch_done.

cancelled[source]#

Set to True when the channel is unregistered while batches are still in flight. Mutated by _clear_registered_input_locked.

batch_size: int[source]#
cancelled: bool = False[source]#
channel_id: int[source]#
completed_batches: int = 0[source]#
drop_last: bool[source]#
epoch: int[source]#
input_len: int[source]#
seeds_index: torch.Tensor | None[source]#
submitted_batches: int = 0[source]#
total_batches: int[source]#
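For illustration, the batch count implied by input_len, batch_size, and drop_last can be computed as follows (`compute_total_batches` is a hypothetical helper; the worker pre-computes an equivalent value per epoch):

```python
def compute_total_batches(input_len: int, batch_size: int, drop_last: bool) -> int:
    # Full batches, plus one trailing partial batch unless drop_last.
    full, remainder = divmod(input_len, batch_size)
    return full if (drop_last or remainder == 0) else full + 1
```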
class gigl.distributed.graph_store.shared_dist_sampling_producer.RegisterInputCmd[source]#

Payload for SharedMpCommand.REGISTER_INPUT.

Carries everything a worker needs to set up sampling for one channel.

channel_id[source]#

Unique identifier for this channel across the backend.

worker_key[source]#

Routing key used to identify this channel in the worker group (passed through to create_dist_sampler).

sampler_input[source]#

The full set of seed node/edge inputs for this channel, already in shared memory.

sampling_config[source]#

Sampling parameters (batch size, num neighbors, etc.).

channel[source]#

The output channel where sampled subgraphs are written.

channel: graphlearn_torch.channel.ChannelBase[source]#
channel_id: int[source]#
sampler_input: gigl.distributed.utils.dist_sampler.SamplerInput[source]#
sampling_config: graphlearn_torch.sampler.SamplingConfig[source]#
worker_key: str[source]#
class gigl.distributed.graph_store.shared_dist_sampling_producer.SharedDistSamplingBackend(*, data, worker_options, sampling_config, sampler_options, degree_tensors)[source]#

Shared graph-store sampling backend reused across many remote channels.

Initialize the shared sampling backend.

Does not start worker processes — call init_backend to spawn them.

Parameters:
  • data (graphlearn_torch.distributed.DistDataset) – The distributed dataset to sample from.

  • worker_options (graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions) – GLT remote sampling worker configuration (RPC addresses, devices, concurrency).

  • sampling_config (graphlearn_torch.sampler.SamplingConfig) – Sampling parameters (batch size, neighbor counts, shuffle, etc.). All channels registered on this backend must use the same config.

  • sampler_options (gigl.distributed.sampler_options.SamplerOptions) – GiGL sampler variant configuration (e.g. PPRSamplerOptions for PPR-based sampling).

  • degree_tensors (Optional[Union[torch.Tensor, dict[graphlearn_torch.typing.EdgeType, torch.Tensor]]]) – Pre-computed degree tensors for PPR sampling (if applicable).

describe_channel(channel_id)[source]#

Return lightweight diagnostics for one registered channel.

Drains pending completion events before building the snapshot.

Parameters:

channel_id (int) – The channel to describe.

Returns:

A dict with keys:

  • "epoch": Current epoch number (-1 if never started).

  • "input_sizes": Per-worker seed counts.

  • "completed_workers": Number of workers that finished the current epoch.

Return type:

dict

init_backend()[source]#

Initialize worker processes once for this backend.

Spawns num_workers subprocesses running _shared_sampling_worker_loop. Each worker initializes RPC and signals readiness via a shared barrier. This method blocks until all workers are ready.

The initialization sequence is:

  1. Assign devices and worker ranks from the GLT server context.

  2. Spawn worker processes with per-worker task queues and a shared event queue.

  3. Wait on the barrier for all workers to finish RPC init.

No-op if already initialized.

Raises:

RuntimeError – If no GLT server context is active.

Return type:

None
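The spawn-then-barrier pattern can be sketched with threads standing in for subprocesses (`spawn_and_wait_ready` and `worker_fn` are illustrative names, not part of this API):

```python
import threading

def spawn_and_wait_ready(num_workers, worker_fn, timeout=30.0):
    # One barrier party per worker plus one for the caller: the caller's
    # wait() returns only after every worker has signaled readiness.
    barrier = threading.Barrier(num_workers + 1)
    workers = []
    for rank in range(num_workers):
        t = threading.Thread(target=worker_fn, args=(rank, barrier), daemon=True)
        t.start()
        workers.append(t)
    barrier.wait(timeout=timeout)
    return workers
```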

is_channel_epoch_done(channel_id, epoch)[source]#

Return whether every worker finished the epoch for one channel.

Drains pending completion events before checking.

Parameters:
  • channel_id (int) – The channel to query.

  • epoch (int) – The epoch number to check.

Returns:

True if all num_workers workers have reported EPOCH_DONE for this (channel_id, epoch) pair.

Raises:

RuntimeError – If any worker process has died.

Return type:

bool
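A typical caller polls this method until the epoch completes. A minimal sketch (`wait_for_epoch` is a hypothetical helper; `backend` is anything exposing `is_channel_epoch_done`):

```python
import time

def wait_for_epoch(backend, channel_id, epoch, poll_secs=0.05, timeout=60.0):
    # Poll until every worker has reported EPOCH_DONE for this
    # (channel_id, epoch) pair, or give up after `timeout` seconds.
    deadline = time.monotonic() + timeout
    while not backend.is_channel_epoch_done(channel_id, epoch):
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"epoch {epoch} on channel {channel_id} not done after {timeout}s"
            )
        time.sleep(poll_secs)
```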

register_input(channel_id, worker_key, sampler_input, sampling_config, channel)[source]#

Register a new channel on all backend workers.

Moves sampler_input into shared memory, computes per-worker seed ranges, initializes shuffle state (if configured), and broadcasts a REGISTER_INPUT command to every worker.

Parameters:
  • channel_id (int) – Unique identifier for this channel.

  • worker_key (str) – Routing key for the channel in the worker group.

  • sampler_input (gigl.distributed.utils.dist_sampler.SamplerInput) – Seed node/edge inputs for this channel.

  • sampling_config (graphlearn_torch.sampler.SamplingConfig) – Must match the backend’s sampling_config.

  • channel (graphlearn_torch.channel.ChannelBase) – Output channel where sampled subgraphs are written.

Raises:
  • RuntimeError – If the backend has not been initialized via init_backend.

  • ValueError – If channel_id is already registered, or if sampling_config does not match the backend config.

Return type:

None

shutdown()[source]#

Stop all worker processes and release backend resources.

Cleanup sequence:

  1. Send STOP to every worker’s task queue.

  2. Join each worker with a timeout of MP_STATUS_CHECK_INTERVAL seconds.

  3. Close all task queues and the event queue.

  4. Terminate any workers still alive after the join timeout.

No-op if already shut down.

Return type:

None

start_new_epoch_sampling(channel_id, epoch)[source]#

Start a new sampling epoch for one registered channel.

Cleans up stale completion records, generates a shuffled or sequential seed permutation, slices it into per-worker ranges, and dispatches START_EPOCH commands to all workers.

No-op if the channel has already started an epoch >= epoch.

Callers must ensure the previous epoch’s sampling has completed before starting a new one. BaseDistLoader.__iter__ guarantees this by consuming all batches via __next__ before re-entering the loop.

Parameters:
  • channel_id (int) – The registered channel to start.

  • epoch (int) – Monotonically increasing epoch number.

Raises:

KeyError – If channel_id is not registered.

Return type:

None
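The permutation-and-slice step can be sketched as follows (`per_worker_seed_slices` is an illustrative helper using Python lists where the module uses torch tensors):

```python
import random

def per_worker_seed_slices(num_seeds, num_workers, shuffle, seed=None):
    # Build one (optionally shuffled) permutation of seed indices, then
    # slice it into contiguous, near-equal per-worker ranges.
    perm = list(range(num_seeds))
    if shuffle:
        random.Random(seed).shuffle(perm)
    base, extra = divmod(num_seeds, num_workers)
    slices, start = [], 0
    for rank in range(num_workers):
        size = base + (1 if rank < extra else 0)
        slices.append(perm[start:start + size])
        start += size
    return slices
```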

unregister_input(channel_id)[source]#

Unregister a channel from all backend workers.

Removes backend-side bookkeeping and broadcasts UNREGISTER_INPUT to every worker.

No-op if channel_id is not currently registered.

Parameters:

channel_id (int) – The channel to remove.

Return type:

None

data[source]#
num_workers[source]#
worker_options[source]#
class gigl.distributed.graph_store.shared_dist_sampling_producer.SharedMpCommand(*args, **kwds)[source]#

Bases: enum.Enum

Commands sent from the backend to worker subprocesses via task queues.

Each command is paired with a payload in a (command, payload) tuple placed on the per-worker task_queue.

REGISTER_INPUT[source]#

Register a new channel with its sampler input, sampling config, and output channel. Payload: RegisterInputCmd.

UNREGISTER_INPUT[source]#

Remove a channel and clean up its state. Payload: int (the channel_id).

START_EPOCH[source]#

Begin sampling a new epoch for one channel. Payload: StartEpochCmd.

STOP[source]#

Shut down the worker process. Payload: None.

REGISTER_INPUT[source]#
START_EPOCH[source]#
STOP[source]#
UNREGISTER_INPUT[source]#
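The (command, payload) tuple format can be sketched with a local stand-in enum and a thread-safe queue (the real backend places these tuples on per-worker multiprocessing queues):

```python
import enum
import queue

class Cmd(enum.Enum):
    # Local stand-in for SharedMpCommand.
    REGISTER_INPUT = enum.auto()
    START_EPOCH = enum.auto()
    UNREGISTER_INPUT = enum.auto()
    STOP = enum.auto()

task_queue: "queue.Queue" = queue.Queue()
task_queue.put((Cmd.START_EPOCH, {"channel_id": 0, "epoch": 1}))
task_queue.put((Cmd.STOP, None))

# A worker pops one (command, payload) tuple at a time.
command, payload = task_queue.get()
```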
class gigl.distributed.graph_store.shared_dist_sampling_producer.StartEpochCmd[source]#

Payload for SharedMpCommand.START_EPOCH.

channel_id[source]#

The channel whose epoch is starting.

epoch[source]#

Monotonically increasing epoch number. Duplicate or stale epochs are silently ignored by the worker.

seeds_index[source]#

Index tensor selecting which seeds from the channel’s sampler_input to sample this epoch. None means use the full input range.

channel_id: int[source]#
epoch: int[source]#
seeds_index: torch.Tensor | None[source]#
gigl.distributed.graph_store.shared_dist_sampling_producer.CommandPayload[source]#
gigl.distributed.graph_store.shared_dist_sampling_producer.EPOCH_DONE_EVENT = 'EPOCH_DONE'[source]#
gigl.distributed.graph_store.shared_dist_sampling_producer.SCHEDULER_SLOW_SUBMIT_SECS = 1.0[source]#
gigl.distributed.graph_store.shared_dist_sampling_producer.SCHEDULER_STATE_LOG_INTERVAL_SECS = 10.0[source]#
gigl.distributed.graph_store.shared_dist_sampling_producer.SCHEDULER_STATE_MAX_CHANNELS = 6[source]#
gigl.distributed.graph_store.shared_dist_sampling_producer.SCHEDULER_TICK_SECS = 0.05[source]#
gigl.distributed.graph_store.shared_dist_sampling_producer.logger[source]#