gigl.distributed.graph_store.dist_server#

GiGL implementation of GLT DistServer.

Uses GiGL’s DistSamplingProducer which supports neighbor sampling and ABLP (anchor-based link prediction) via BaseGiGLSampler subclasses (DistNeighborSampler for k-hop, DistPPRNeighborSampler for PPR).

Based on alibaba/graphlearn-for-pytorch

Teardown protocol#

Cluster teardown is a strict three-phase sequence:

  1. Per-loader (compute side). Each DistLoader.shutdown() issues DistServer.destroy_sampling_input(channel_id) against every storage server it registered with. This call performs the actual teardown work on the server: runtime.unregister_input for the channel and, when the channel is the last on its backend, runtime.shutdown for the backend.

  2. Per-compute-process. After all loaders are torn down, the compute process calls gigl.distributed.graph_store.compute.shutdown_compute_process, which calls glt.distributed.shutdown_client and tears down the compute torch process group.

  3. Per-storage-process. wait_and_shutdown_server blocks until DistServer.exit flips, then runs DistServer.shutdown() (strict validation — see below), barrier(), shutdown_rpc().

An example of the teardown protocol:

```python
def compute_process():
    train_loader = ...
    val_loader = ...
    test_loader = ...
    loaders = [train_loader, val_loader, test_loader]

    for data in train_loader:
        train(data)
        if should_val():
            for data in val_loader:
                val(data)

    # Step 1: Shut down the loaders after training and validation.
    train_loader.shutdown()
    val_loader.shutdown()

    for data in test_loader:
        test(data)

    # Shut down the loader after testing.
    test_loader.shutdown()

    # Step 2: Per-compute-process.
    shutdown_compute_process()


def storage_process():
    # Step 3: Per-storage-process.
    wait_and_shutdown_server()
```

DistServer.shutdown() does no teardown work itself. It validates that phase 1 ran for every channel/backend, raises RuntimeError if state remains, and drops residual tombstone/stats bookkeeping. The actual runtime.shutdown() work happens in destroy_sampling_input when the last channel on a backend is destroyed. wait_and_shutdown_server catches DistServer.shutdown exceptions so a buggy compute client cannot wedge healthy storage peers on the barrier; the failing storage process re-raises after the barrier so the orchestrator sees a non-zero exit.
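The last-channel rule above can be modeled as a small reference-count registry. This is a simplified sketch, not GiGL's actual implementation: `FakeRuntime` and `Registry` are illustrative stand-ins for the real backend runtime and server state.

```python
# Simplified model of the phase-1 teardown rule: destroying the last
# channel on a backend is what actually shuts the backend runtime down.
# FakeRuntime and Registry are illustrative stand-ins, not GiGL APIs.


class FakeRuntime:
    def __init__(self):
        self.shut_down = False

    def unregister_input(self, channel_id):
        pass  # real code would tear down per-channel producer state here

    def shutdown(self):
        self.shut_down = True


class Registry:
    def __init__(self):
        self.channels = {}  # channel_id -> backend_id
        self.backends = {}  # backend_id -> (runtime, set of channel_ids)

    def destroy_sampling_input(self, channel_id):
        backend_id = self.channels.pop(channel_id)
        runtime, active = self.backends[backend_id]
        runtime.unregister_input(channel_id)
        active.discard(channel_id)
        if not active:
            # Last channel on this backend: shut it down and drop it.
            runtime.shutdown()
            del self.backends[backend_id]


reg = Registry()
rt = FakeRuntime()
reg.backends[0] = (rt, {1, 2})
reg.channels = {1: 0, 2: 0}
reg.destroy_sampling_input(1)  # backend stays alive: one channel remains
reg.destroy_sampling_input(2)  # last channel: backend runtime shuts down
```

After both channels are destroyed, a subsequent `shutdown()` finds no remaining state, which is exactly the invariant the strict validation checks.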

Attributes#

FETCH_SLOW_LOG_SECS

R

SERVER_EXIT_STATUS_CHECK_INTERVAL

Interval (in seconds) to check exit status of server.

logger

Classes#

ChannelState

Per-channel state for a registered sampling input.

DistServer

A server that supports launching remote sampling workers for training clients.

SamplingBackendState

Per-backend state for a shared sampling backend.

Functions#

get_server()

Get the DistServer instance on the current process.

init_server(num_servers, server_rank, dataset, ...[, ...])

Initialize the current process as a server and establish connections with all other servers and clients.

wait_and_shutdown_server()

Block until all clients have shut down, then shut down the server on the current process.

Module Contents#

class gigl.distributed.graph_store.dist_server.ChannelState[source]#

Per-channel state for a registered sampling input.

Parameters:
  • backend_id – The ID of the backend this channel belongs to.

  • worker_key – The unique key identifying this compute-rank channel.

  • channel – The shared-memory channel for passing sampled messages.

  • epoch – The last epoch started on this channel.

  • lock – A reentrant lock guarding channel-level operations.

  • tombstoned – Terminal-state tombstone. Set to True once destroy_sampling_input has finished cleaning up this channel. The channel object may briefly outlive its registry entry — an in-flight fetch_one_sampled_message on another RPC thread reads this flag to break out of its recv loop. Past-participle name: this is “done” state, not “in progress” state. Asymmetric with SamplingBackendState.tearing_down by design — backends are removed from the registry when teardown completes, so no terminal flag is needed there; channels are not removed until the in-flight fetch observes the tombstone, so the flag has to live on the object.

backend_id: int[source]#
channel: graphlearn_torch.channel.ShmChannel[source]#
epoch: int = -1[source]#
lock: threading.RLock[source]#
tombstoned: bool = False[source]#
worker_key: str[source]#
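The tombstone handshake described above can be modeled in a few lines: a destroyer sets `tombstoned` while an in-flight fetch on another RPC thread polls the flag to break out of its recv loop. This is a simplified sketch; `Channel` and `fetch_one` are illustrative stand-ins, not GiGL's API.

```python
import queue
import threading

# Minimal model of the tombstone handshake: an in-flight fetch on one RPC
# thread polls the channel object's tombstoned flag so a concurrent destroy
# on another thread can break it out of its recv loop.


class Channel:
    def __init__(self):
        self.q = queue.Queue()
        self.tombstoned = False


def fetch_one(ch, poll_secs=0.01):
    while True:
        if ch.tombstoned:
            return None, True  # channel destroyed: report "done"
        try:
            return ch.q.get(timeout=poll_secs), False
        except queue.Empty:
            continue  # nothing yet: re-check the tombstone and retry


ch = Channel()
destroyer = threading.Thread(target=lambda: setattr(ch, "tombstoned", True))
destroyer.start()
msg, done = fetch_one(ch)  # unblocks once the destroyer sets the flag
destroyer.join()
```

This is why the flag must live on the channel object itself: the registry entry may be gone before the in-flight fetch gets a chance to observe it.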
class gigl.distributed.graph_store.dist_server.DistServer(dataset, log_every_n=50)[source]#

A server that supports launching remote sampling workers for training clients.

Note that this server is enabled only when the distribution mode is a server-client framework, and the graph and feature store will be partitioned and managed by all server nodes.

Parameters:
  • dataset (DistDataset) – The DistDataset object of a partition of graph data and feature data, along with distributed partition books.

  • log_every_n (int) – Log aggregated fetch_one_sampled_message timing stats after every N fetch attempts per channel.

destroy_sampling_input(channel_id)[source]#

Destroy one registered sampling channel and maybe its backend.

If this is the last channel on the backend, the backend is shut down and removed.

Parameters:

channel_id (int) – The ID of the channel to destroy.

Return type:

None

exit()[source]#

Set the exit flag to True.

Return type:

bool

fetch_one_sampled_message(channel_id)[source]#

Fetch one sampled message from a registered channel.

Parameters:

channel_id (int) – The ID of the channel to fetch from.

Returns:

A tuple of (message, is_done). If is_done is True, no more messages will be produced for this epoch.

Return type:

tuple[Optional[graphlearn_torch.channel.SampleMessage], bool]
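A caller drains one epoch by fetching until `is_done` flips. The sketch below shows that consumption pattern against a stand-in server; `FakeServer` and `iter_epoch` are illustrative, not GiGL's actual client machinery.

```python
class FakeServer:
    """Illustrative stand-in for DistServer: serves a fixed message list."""

    def __init__(self, msgs):
        self._msgs = list(msgs)

    def fetch_one_sampled_message(self, channel_id):
        if not self._msgs:
            return None, True
        msg = self._msgs.pop(0)
        return msg, not self._msgs  # is_done is True on the last message


def iter_epoch(server, channel_id):
    # Drain one epoch: yield messages until the server reports is_done.
    # Note is_done may arrive together with a valid message, so the message
    # is yielded before the done check.
    while True:
        msg, is_done = server.fetch_one_sampled_message(channel_id)
        if msg is not None:
            yield msg
        if is_done:
            return


batches = list(iter_epoch(FakeServer(["b0", "b1", "b2"]), channel_id=0))
```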

get_ablp_input(request)[source]#

Get the ABLP (Anchor Based Link Prediction) input for distributed processing.

Parameters:

request (gigl.distributed.graph_store.messages.FetchABLPInputRequest) – The ABLP fetch request, including split, node type, supervision edge type, and an optional contiguous server slice.

Returns:

A tuple containing the anchor nodes for the rank, the positive labels, and the negative labels. The positive labels are of shape [N, M], where N is the number of anchor nodes and M is the number of positive labels. The negative labels are of shape [N, M], where N is the number of anchor nodes and M is the number of negative labels. The negative labels may be None if no negative labels are available.

Raises:

ValueError – If the split is invalid.

Return type:

tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

get_dataset_meta()[source]#

Get the meta info of the distributed dataset managed by the current server, including partition info and graph types.

Return type:

tuple[int, int, Optional[list[gigl.src.common.types.graph_data.NodeType]], Optional[list[gigl.src.common.types.graph_data.EdgeType]]]

get_edge_dir()[source]#

Get the edge direction from the dataset.

Returns:

The edge direction.

Return type:

Literal['in', 'out']

get_edge_feature_info()[source]#

Get edge feature information from the dataset.

Returns:

Edge feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping EdgeType to FeatureInfo for heterogeneous graphs

  • None if no edge features are available

get_edge_index(edge_type, layout)[source]#
Parameters:
Return type:

tuple[torch.Tensor, torch.Tensor]

get_edge_partition_book(edge_type)[source]#

Gets the partition book for the specified edge type.

Parameters:

edge_type (Optional[gigl.src.common.types.graph_data.EdgeType]) – The edge type to look up. Must be None for homogeneous datasets and non-None for heterogeneous ones.

Returns:

The partition book for the requested edge type, or None if no partition book is available.

Raises:

ValueError – If edge_type is mismatched with the dataset type.

Return type:

Optional[graphlearn_torch.partition.PartitionBook]

get_edge_size(edge_type, layout)[source]#
Parameters:
Return type:

tuple[int, int]

get_edge_types()[source]#

Get the edge types from the dataset.

Returns:

The edge types in the dataset, None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.EdgeType]]

get_node_feature(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

torch.Tensor

get_node_feature_info()[source]#

Get node feature information from the dataset.

Returns:

Node feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping NodeType to FeatureInfo for heterogeneous graphs

  • None if no node features are available

get_node_ids(request)[source]#

Get the node ids from the dataset.

Parameters:

request (gigl.distributed.graph_store.messages.FetchNodesRequest) – The node-fetch request, including split, node type, and an optional contiguous server slice.

Returns:

The node ids.

Raises:

ValueError –

  • If the split is invalid

  • If the node ids are not a torch.Tensor or a dict[NodeType, torch.Tensor]

  • If the node type is provided for a homogeneous dataset

  • If the node ids are not a dict[NodeType, torch.Tensor] when no node type is provided

Note: When split=None, all nodes are queryable. This means nodes from any split (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split.

Return type:

torch.Tensor

get_node_label(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

torch.Tensor

get_node_partition_book(node_type)[source]#

Gets the partition book for the specified node type.

Parameters:

node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – The node type to look up. Must be None for homogeneous datasets and non-None for heterogeneous ones.

Returns:

The partition book for the requested node type, or None if no partition book is available.

Raises:

ValueError – If node_type is mismatched with the dataset type.

Return type:

Optional[graphlearn_torch.partition.PartitionBook]

get_node_partition_id(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

Optional[torch.Tensor]

get_node_types()[source]#

Get the node types from the dataset.

Returns:

The node types in the dataset, None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.NodeType]]

get_tensor_size(node_type)[source]#
Parameters:

node_type (Optional[gigl.src.common.types.graph_data.NodeType])

Return type:

torch.Size

init_sampling_backend(opts)[source]#

Create or reuse a shared sampling backend for one loader instance.

If a backend with the same backend_key already exists, blocks until the original initializer has finished, then returns its ID. If the original initializer failed, re-raises that failure so the second caller does not see a half-initialized backend.

Parameters:

opts (gigl.distributed.graph_store.messages.InitSamplingBackendRequest) – The initialization request containing the backend key, worker options, sampler options, and sampling config.

Returns:

The unique backend ID.

Raises:

RuntimeError – If a prior concurrent initialization for the same backend_key failed, or if the backend is currently being torn down by destroy_sampling_input.

Return type:

int
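The create-or-reuse contract can be sketched with a lock plus a per-backend completion event: the first caller for a `backend_key` initializes, later callers block until init finishes and then reuse the ID or re-raise the first caller's failure. `BackendState` and `do_init` here are simplified stand-ins, not GiGL's actual classes.

```python
import threading

# Sketch of init_sampling_backend's create-or-reuse semantics. The first
# caller for a backend_key runs the initializer; concurrent callers wait on
# init_done and then either reuse the backend ID or see the first caller's
# failure re-raised, so nobody observes a half-initialized backend.


class BackendState:
    def __init__(self, backend_id):
        self.backend_id = backend_id
        self.init_done = threading.Event()
        self.init_error = None


_registry = {}
_registry_lock = threading.Lock()
_next_id = [0]


def init_sampling_backend(backend_key, do_init):
    with _registry_lock:
        state = _registry.get(backend_key)
        first = state is None
        if first:
            state = BackendState(_next_id[0])
            _next_id[0] += 1
            _registry[backend_key] = state
    if first:
        try:
            do_init()  # real code would build the shared sampling runtime
        except BaseException as e:
            state.init_error = e
        finally:
            state.init_done.set()
    else:
        state.init_done.wait()  # block on the original initializer
    if state.init_error is not None:
        raise RuntimeError("backend init failed") from state.init_error
    return state.backend_id


a = init_sampling_backend("dist_neighbor_loader_0", do_init=lambda: None)
b = init_sampling_backend("dist_neighbor_loader_0", do_init=lambda: None)
```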

register_sampling_input(opts)[source]#

Register one compute-rank input channel on an existing backend.

Parameters:

opts (gigl.distributed.graph_store.messages.RegisterBackendRequest) – The registration request containing the backend ID, worker key, sampler input, sampling config, and buffer settings.

Returns:

The unique channel ID for this input.

Raises:
  • KeyError – If opts.backend_id does not refer to a registered backend (caller bug; backend must be registered first).

  • Exception – Re-raises any failure from runtime.register_input after rolling back the partial channel state.

Return type:

int

shutdown()[source]#

Final post-teardown validation and bookkeeping cleanup; does not perform any shutdown work itself.

Strict contract: callers must tear down every channel and backend via destroy_sampling_input before invoking this method. The actual runtime.shutdown() work happens in destroy_sampling_input when the last channel on a backend is destroyed; this method validates that no sampling state is still registered, raises RuntimeError if any remains, and drops residual tombstone/stats bookkeeping.

Return type:

None

start_new_epoch_sampling(channel_id, epoch)[source]#

Start one new epoch on one registered channel.

No-op if this channel has already started this epoch or a later one (idempotent; safe to call repeatedly from retries).

No-op if the channel is in the tombstoned set: this handles the legitimate destroy/start race window where a compute peer’s start RPC arrives after destroy has landed.

Parameters:
  • channel_id (int) – The ID of the channel to start the epoch on.

  • epoch (int) – The epoch number to start.

Raises:

RuntimeError – If channel_id was never registered on this server (vs. legitimately tombstoned — tombstoned ids are treated as a silent no-op).

Return type:

None
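The three cases above (idempotent retry, tombstone race, unknown channel) can be sketched as a small state machine. The dict-based channel registry below is a simplified stand-in for the server's real `ChannelState` bookkeeping.

```python
# Sketch of start_new_epoch_sampling's semantics: repeat or stale epoch
# requests are no-ops, tombstoned channels are silent no-ops (the
# destroy/start race), and an unknown channel_id is a hard error.

channels = {
    7: {"epoch": -1, "tombstoned": False},
    8: {"epoch": -1, "tombstoned": True},  # already destroyed
}
started = []


def start_new_epoch_sampling(channel_id, epoch):
    ch = channels.get(channel_id)
    if ch is None:
        # Never registered on this server: caller bug.
        raise RuntimeError(f"channel {channel_id} was never registered")
    if ch["tombstoned"]:
        return  # destroy/start race window: silently ignore
    if epoch <= ch["epoch"]:
        return  # this epoch (or a later one) already started: no-op
    ch["epoch"] = epoch
    started.append(epoch)


start_new_epoch_sampling(7, 0)
start_new_epoch_sampling(7, 0)  # retry of the same epoch: no-op
start_new_epoch_sampling(7, 1)
start_new_epoch_sampling(8, 0)  # tombstoned channel: silent no-op
```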

wait_for_exit()[source]#

Block until the exit flag has been set to True.

Return type:

None

dataset[source]#
class gigl.distributed.graph_store.dist_server.SamplingBackendState[source]#

Per-backend state for a shared sampling backend.

Parameters:
  • backend_id – The unique ID of this backend.

  • backend_key – The key identifying this backend (e.g. "dist_neighbor_loader_0").

  • runtime – The shared sampling backend runtime.

  • active_channels – Set of channel IDs currently registered on this backend.

  • lock – A reentrant lock guarding backend-level operations.

  • init_complete – Whether runtime.init_backend() has completed successfully.

  • init_error – If runtime.init_backend() raised, the exception; otherwise None.

  • tearing_down – In-progress marker. Set to True while destroy_sampling_input is tearing down this backend’s runtime. A second-caller of init_sampling_backend that blocks on lock and observes this flag raises RuntimeError so it does not reuse a half-shutdown runtime. Present-participle name: this is “in progress” state, not “done” state. Asymmetric with ChannelState.tombstoned by design — the registry entry is removed once teardown completes, so no terminal flag is needed; after removal nothing can find this object.

active_channels: set[int][source]#
backend_id: int[source]#
backend_key: str[source]#
init_complete: bool = False[source]#
init_error: BaseException | None = None[source]#
lock: threading.RLock[source]#
runtime: gigl.distributed.graph_store.shared_dist_sampling_producer.SharedDistSamplingBackend[source]#
tearing_down: bool = False[source]#
gigl.distributed.graph_store.dist_server.get_server()[source]#

Get the DistServer instance on the current process.

Return type:

DistServer

gigl.distributed.graph_store.dist_server.init_server(num_servers, server_rank, dataset, master_addr, master_port, num_clients=0, num_rpc_threads=16, request_timeout=180, server_group_name=None, is_dynamic=False)[source]#

Initialize the current process as a server and establish connections with all other servers and clients. Note that this method should be called only in the server-client distribution mode.

Parameters:
  • num_servers (int) – Number of processes participating in the server group.

  • server_rank (int) – Rank of the current process within the server group (a number between 0 and num_servers-1).

  • dataset (DistDataset) – The DistDataset object of a partition of graph data and feature data, along with distributed partition book info.

  • master_addr (str) – The master TCP address for RPC connections between all servers and clients; this value must be the same on all servers and clients.

  • master_port (int) – The master TCP port for RPC connections between all servers and clients; this value must be the same on all servers and clients.

  • num_clients (int) – Number of processes participating in the client group. If is_dynamic is True, this parameter is ignored.

  • num_rpc_threads (int) – The number of RPC worker threads the current server uses to respond to remote requests. (Default: 16).

  • request_timeout (int) – The maximum time in seconds to wait for a remote request before an exception is raised. (Default: 180).

  • server_group_name (str) – A unique name for the server group that the current process belongs to. If set to None, a default name will be used. (Default: None).

  • is_dynamic (bool) – Whether the world size is dynamic. (Default: False).

Return type:

None

gigl.distributed.graph_store.dist_server.wait_and_shutdown_server()[source]#

Block until all clients have shut down, then shut down the server on the current process and destroy all RPC connections.

Best-effort cluster liveness: if DistServer.shutdown raises (e.g. because a client crashed leaving state behind) we capture the exception, run barrier() + shutdown_rpc() so healthy storage peers do not hang on the barrier, then re-raise so the orchestrator sees a non-zero exit on the failing storage process.

Step 3 of the three-phase teardown described in the module docstring.

Return type:

None
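The capture-barrier-reraise pattern can be sketched in isolation. Here `shutdown`, `barrier`, and `shutdown_rpc` are injected stubs standing in for `DistServer.shutdown`, GLT's `barrier()`, and `shutdown_rpc()`; the real function takes no arguments.

```python
# Sketch of the best-effort liveness pattern: capture a shutdown failure,
# still run barrier() + shutdown_rpc() so healthy storage peers do not hang,
# then re-raise so the orchestrator sees a non-zero exit on this process.


def wait_and_shutdown_server(shutdown, barrier, shutdown_rpc):
    err = None
    try:
        shutdown()  # may raise, e.g. a crashed client left state behind
    except Exception as e:
        err = e  # capture, but keep the cluster teardown moving
    barrier()  # healthy peers are not wedged by our failure
    shutdown_rpc()
    if err is not None:
        raise err  # surface the failure after the barrier


def _failing_shutdown():
    raise RuntimeError("stale sampling state left by crashed client")


calls = []
try:
    wait_and_shutdown_server(
        shutdown=_failing_shutdown,
        barrier=lambda: calls.append("barrier"),
        shutdown_rpc=lambda: calls.append("rpc"),
    )
except RuntimeError:
    calls.append("reraised")
```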

gigl.distributed.graph_store.dist_server.FETCH_SLOW_LOG_SECS = 1.0[source]#
gigl.distributed.graph_store.dist_server.R[source]#
gigl.distributed.graph_store.dist_server.SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0[source]#

Interval (in seconds) to check exit status of server.

gigl.distributed.graph_store.dist_server.logger[source]#