gigl.distributed.graph_store.dist_server#

GiGL implementation of GLT DistServer.

Uses GiGL’s DistSamplingProducer, which supports neighbor sampling and ABLP (anchor-based link prediction) via BaseGiGLSampler subclasses (DistNeighborSampler for k-hop sampling, DistPPRNeighborSampler for PPR).

Based on alibaba/graphlearn-for-pytorch

Attributes#

FETCH_SLOW_LOG_SECS

R

SERVER_EXIT_STATUS_CHECK_INTERVAL

Interval (in seconds) to check exit status of server.

logger

Classes#

ChannelState

Per-channel state for a registered sampling input.

DistServer

A server that supports launching remote sampling workers for training clients.

SamplingBackendState

Per-backend state for a shared sampling backend.

Functions#

get_server()

Get the DistServer instance on the current process.

init_server(num_servers, server_rank, dataset, ...[, ...])

Initialize the current process as a server and establish connections with all other servers and clients.

wait_and_shutdown_server()

Block until all clients have shut down, then shut down the server on the current process and destroy all RPC connections.

Module Contents#

class gigl.distributed.graph_store.dist_server.ChannelState[source]#

Per-channel state for a registered sampling input.

Parameters:
  • backend_id – The ID of the backend this channel belongs to.

  • worker_key – The unique key identifying this compute-rank channel.

  • channel – The shared-memory channel for passing sampled messages.

  • epoch – The last epoch started on this channel.

  • lock – A reentrant lock guarding channel-level operations.

backend_id: int[source]#
channel: graphlearn_torch.channel.ShmChannel[source]#
epoch: int = -1[source]#
lock: threading.RLock[source]#
worker_key: str[source]#
class gigl.distributed.graph_store.dist_server.DistServer(dataset, log_every_n=50)[source]#

A server that supports launching remote sampling workers for training clients.

Note that this server is enabled only when the distribution mode is a server-client framework, and the graph and feature store will be partitioned and managed by all server nodes.

Parameters:
  • dataset (DistDataset) – The DistDataset object of a partition of graph data and feature data, along with distributed partition books.

  • log_every_n (int) – Log aggregated fetch_one_sampled_message timing stats after every N fetch attempts per channel.

create_sampling_producer(sampler_input, sampling_config, worker_options, sampler_options)[source]#

Create a sampling producer by delegating to the two-phase API.

Bridge method that keeps existing loaders working. Internally calls init_sampling_backend() and register_sampling_input(), returning the channel_id as the producer_id.

Parameters:
  • sampler_input (Union[graphlearn_torch.sampler.NodeSamplerInput, graphlearn_torch.sampler.EdgeSamplerInput, graphlearn_torch.sampler.RemoteSamplerInput, gigl.distributed.sampler.ABLPNodeSamplerInput]) – The input data for sampling.

  • sampling_config (graphlearn_torch.sampler.SamplingConfig) – Configuration of sampling meta info.

  • worker_options (graphlearn_torch.distributed.RemoteDistSamplingWorkerOptions) – Options for launching remote sampling workers.

  • sampler_options (gigl.distributed.sampler_options.SamplerOptions) – Controls which sampler class is instantiated.

Returns:

A unique ID (channel_id) usable as a producer_id.

Return type:

int

destroy_sampling_input(channel_id)[source]#

Destroy one registered sampling channel and maybe its backend.

If this is the last channel on the backend, the backend is shut down and removed.

Caller contract: callers must have drained the channel (i.e., observed fetch_one_sampled_message() return (None, True)) before calling destroy. This is required because fetch_one_sampled_message() holds channel_state.lock for the duration of its recv loop; a destroy issued mid-epoch will block waiting for the fetch to exit, and if the producer is stuck, destroy will block indefinitely. All in-tree callers satisfy this contract via BaseDistLoader’s teardown sequence.

Parameters:

channel_id (int) – The ID of the channel to destroy.

Return type:

None
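The drain-then-destroy caller contract above can be sketched as a small client-side helper. This is a hypothetical sketch, not part of the GiGL API: `server` stands for any object exposing the two documented methods (in practice, the RPC proxy to DistServer).

```python
from typing import Any


def drain_and_destroy(server: Any, channel_id: int) -> int:
    """Consume a sampling channel until the server signals end-of-epoch,
    then destroy it. Returns the number of messages consumed.

    Satisfies the caller contract of destroy_sampling_input(): destroy is
    only issued after fetch_one_sampled_message() returns (None, True).
    """
    consumed = 0
    while True:
        message, is_done = server.fetch_one_sampled_message(channel_id)
        if message is not None:
            consumed += 1  # in real code, hand the message to the loader
        if is_done:
            break
    # The channel is drained, so destroy cannot block on an in-flight fetch.
    server.destroy_sampling_input(channel_id)
    return consumed
```

This is the same teardown sequence that BaseDistLoader performs; the helper only makes the ordering explicit.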

destroy_sampling_producer(producer_id)[source]#

Destroy a sampling producer by delegating to destroy_sampling_input().

Bridge method that keeps existing loaders working.

Parameters:

producer_id (int) – The producer ID (channel_id) to destroy.

Return type:

None

exit()[source]#

Set the exit flag to True.

Return type:

bool

fetch_one_sampled_message(channel_id)[source]#

Fetch one sampled message from a registered channel.

Parameters:

channel_id (int) – The ID of the channel to fetch from.

Returns:

A tuple of (message, is_done). If is_done is True, no more messages will be produced for this epoch.

Return type:

tuple[Optional[graphlearn_torch.channel.SampleMessage], bool]

get_ablp_input(request)[source]#

Get the ABLP (Anchor Based Link Prediction) input for distributed processing.

Parameters:

request (gigl.distributed.graph_store.messages.FetchABLPInputRequest) – The ABLP fetch request, including split, node type, supervision edge type, and an optional contiguous server slice.

Returns:

A tuple containing the anchor nodes for the rank, the positive labels, and the negative labels. The positive labels have shape [N, M], where N is the number of anchor nodes and M is the number of positive labels per anchor; the negative labels likewise have shape [N, M], with M being the number of negative labels per anchor. The negative labels may be None if no negative labels are available.

Raises:

ValueError – If the split is invalid.

Return type:

tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

get_dataset_meta()[source]#

Get the meta info of the distributed dataset managed by the current server, including partition info and graph types.

Return type:

tuple[int, int, Optional[list[gigl.src.common.types.graph_data.NodeType]], Optional[list[gigl.src.common.types.graph_data.EdgeType]]]

get_edge_dir()[source]#

Get the edge direction from the dataset.

Returns:

The edge direction.

Return type:

Literal[‘in’, ‘out’]

get_edge_feature_info()[source]#

Get edge feature information from the dataset.

Returns:

Edge feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping EdgeType to FeatureInfo for heterogeneous graphs

  • None if no edge features are available

get_edge_index(edge_type, layout)[source]#
Parameters:
Return type:

tuple[torch.Tensor, torch.Tensor]

get_edge_partition_book(edge_type)[source]#

Gets the partition book for the specified edge type.

Parameters:

edge_type (Optional[gigl.src.common.types.graph_data.EdgeType]) – The edge type to look up. Must be None for homogeneous datasets and non-None for heterogeneous ones.

Returns:

The partition book for the requested edge type, or None if no partition book is available.

Raises:

ValueError – If edge_type is mismatched with the dataset type.

Return type:

Optional[graphlearn_torch.partition.PartitionBook]

get_edge_size(edge_type, layout)[source]#
Parameters:
Return type:

tuple[int, int]

get_edge_types()[source]#

Get the edge types from the dataset.

Returns:

The edge types in the dataset, None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.EdgeType]]

get_node_feature(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

torch.Tensor

get_node_feature_info()[source]#

Get node feature information from the dataset.

Returns:

Node feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping NodeType to FeatureInfo for heterogeneous graphs

  • None if no node features are available

get_node_ids(request)[source]#

Get the node ids from the dataset.

Parameters:

request (gigl.distributed.graph_store.messages.FetchNodesRequest) – The node-fetch request, including split, node type, and an optional contiguous server slice.

Returns:

The node ids.

Raises:

ValueError –

  • If the split is invalid

  • If the node ids are not a torch.Tensor or a dict[NodeType, torch.Tensor]

  • If a node type is provided for a homogeneous dataset

  • If the node ids are not a dict[NodeType, torch.Tensor] when no node type is provided

Note: When split=None, all nodes are queryable, meaning nodes from any split (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split.

Return type:

torch.Tensor

get_node_label(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

torch.Tensor

get_node_partition_book(node_type)[source]#

Gets the partition book for the specified node type.

Parameters:

node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – The node type to look up. Must be None for homogeneous datasets and non-None for heterogeneous ones.

Returns:

The partition book for the requested node type, or None if no partition book is available.

Raises:

ValueError – If node_type is mismatched with the dataset type.

Return type:

Optional[graphlearn_torch.partition.PartitionBook]

get_node_partition_id(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

Optional[torch.Tensor]

get_node_types()[source]#

Get the node types from the dataset.

Returns:

The node types in the dataset, None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.NodeType]]

get_tensor_size(node_type)[source]#
Parameters:

node_type (Optional[gigl.src.common.types.graph_data.NodeType])

Return type:

torch.Size

init_sampling_backend(opts)[source]#

Create or reuse a shared sampling backend for one loader instance.

If a backend with the same backend_key already exists, blocks until the original initializer has finished, then returns its ID. If the original initializer failed, re-raises that failure so the second caller does not see a half-initialized backend.

Parameters:

opts (gigl.distributed.graph_store.messages.InitSamplingBackendRequest) – The initialization request containing the backend key, worker options, sampler options, and sampling config.

Returns:

The unique backend ID.

Raises:

RuntimeError – If a prior concurrent initialization for the same backend_key failed.

Return type:

int

register_sampling_input(opts)[source]#

Register one compute-rank input channel on an existing backend.

Parameters:

opts (gigl.distributed.graph_store.messages.RegisterBackendRequest) – The registration request containing the backend ID, worker key, sampler input, sampling config, and buffer settings.

Returns:

The unique channel ID for this input.

Return type:

int
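The two-phase flow (init_sampling_backend followed by register_sampling_input) can be sketched as a client-side helper. The helper function and the dict-based stub requests below are our own illustration, not GiGL API; real callers pass the documented InitSamplingBackendRequest and RegisterBackendRequest messages, whose construction is elided here.

```python
from typing import Any, Tuple


def attach_loader(server: Any, init_request: Any, register_request: Any) -> Tuple[int, int]:
    """Two-phase attachment replacing the legacy create_sampling_producer
    bridge: first create (or reuse) the shared backend for this loader's
    backend_key, then register this compute rank's input channel on it.
    """
    backend_id = server.init_sampling_backend(init_request)
    channel_id = server.register_sampling_input(register_request)
    return backend_id, channel_id
```

Two loaders attaching with the same backend key share one backend but receive distinct channel IDs, which is what lets a single sampling backend serve multiple compute ranks.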

shutdown()[source]#
Return type:

None

start_new_epoch_sampling(channel_id, epoch)[source]#

Start one new epoch on one registered channel.

No-op if this channel has already started this epoch or a later one (idempotent: safe to call repeatedly from retries).

Parameters:
  • channel_id (int) – The ID of the channel to start the epoch on.

  • epoch (int) – The epoch number to start.

Raises:

RuntimeError – If the channel or its backend is not found.

Return type:

None
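The idempotence guarantee above means an epoch start can be retried blindly after a transient RPC failure. A minimal sketch of an epoch-driving loop, using only the documented methods (the `run_epochs` helper itself is hypothetical; `server` is any object exposing those methods):

```python
from typing import Any


def run_epochs(server: Any, channel_id: int, num_epochs: int) -> None:
    """Drive num_epochs sampling epochs on one registered channel.

    start_new_epoch_sampling() is idempotent, so the duplicated call below
    (standing in for a retry after a transient failure) is a harmless no-op.
    """
    for epoch in range(num_epochs):
        server.start_new_epoch_sampling(channel_id, epoch)
        server.start_new_epoch_sampling(channel_id, epoch)  # retried call: no-op
        while True:
            message, is_done = server.fetch_one_sampled_message(channel_id)
            if is_done:
                break
            # in real code, feed `message` to the training step
```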

wait_for_exit()[source]#

Block until the exit flag has been set to True.

Return type:

None

dataset[source]#
class gigl.distributed.graph_store.dist_server.SamplingBackendState[source]#

Per-backend state for a shared sampling backend.

Parameters:
  • backend_id – The unique ID of this backend.

  • backend_key – The key identifying this backend (e.g. "dist_neighbor_loader_0").

  • runtime – The shared sampling backend runtime.

  • active_channels – Set of channel IDs currently registered on this backend.

  • lock – A reentrant lock guarding backend-level operations.

  • init_complete – Whether runtime.init_backend() has completed successfully.

  • init_error – If runtime.init_backend() raised, the exception; otherwise None.

active_channels: set[int][source]#
backend_id: int[source]#
backend_key: str[source]#
init_complete: bool = False[source]#
init_error: BaseException | None = None[source]#
lock: threading.RLock[source]#
runtime: gigl.distributed.graph_store.shared_dist_sampling_producer.SharedDistSamplingBackend[source]#
gigl.distributed.graph_store.dist_server.get_server()[source]#

Get the DistServer instance on the current process.

Return type:

DistServer

gigl.distributed.graph_store.dist_server.init_server(num_servers, server_rank, dataset, master_addr, master_port, num_clients=0, num_rpc_threads=16, request_timeout=180, server_group_name=None, is_dynamic=False)[source]#

Initialize the current process as a server and establish connections with all other servers and clients. Note that this method should be called only in the server-client distribution mode.

Parameters:
  • num_servers (int) – Number of processes participating in the server group.

  • server_rank (int) – Rank of the current process within the server group (should be a number between 0 and num_servers-1).

  • dataset (DistDataset) – The DistDataset object of a partition of graph data and feature data, along with distributed partition book info.

  • master_addr (str) – The master TCP address for RPC connections between all servers and clients; this value must be the same on all servers and clients.

  • master_port (int) – The master TCP port for RPC connections between all servers and clients; this value must be the same on all servers and clients.

  • num_clients (int) – Number of processes participating in the client group. If is_dynamic is True, this parameter is ignored.

  • num_rpc_threads (int) – The number of RPC worker threads used by the current server to respond to remote requests. (Default: 16).

  • request_timeout (int) – The maximum number of seconds to wait for a remote request before an exception is raised. (Default: 180).

  • server_group_name (str) – A unique name for the server group that the current process belongs to. If set to None, a default name will be used. (Default: None).

  • is_dynamic (bool) – Whether the world size is dynamic. (Default: False).

Return type:

None
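A typical server-process lifecycle is init_server() followed by wait_and_shutdown_server(). The `serve` wrapper below is our own sketch, parameterized over the module so the call ordering can be exercised without a live RPC setup; in real code you would pass gigl.distributed.graph_store.dist_server itself, along with a real DistDataset.

```python
from typing import Any


def serve(dist_server_mod: Any, dataset: Any, num_servers: int, server_rank: int,
          master_addr: str, master_port: int) -> None:
    """Run one server process: establish RPC connections, serve sampling
    requests until all clients have shut down, then tear down RPC."""
    dist_server_mod.init_server(
        num_servers=num_servers,
        server_rank=server_rank,
        dataset=dataset,
        master_addr=master_addr,
        master_port=master_port,
    )
    # Sampling requests are now served via RPC; block until clients finish.
    dist_server_mod.wait_and_shutdown_server()
```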

gigl.distributed.graph_store.dist_server.wait_and_shutdown_server()[source]#

Block until all clients have shut down, then shut down the server on the current process and destroy all RPC connections.

Return type:

None

gigl.distributed.graph_store.dist_server.FETCH_SLOW_LOG_SECS = 1.0[source]#
gigl.distributed.graph_store.dist_server.R[source]#
gigl.distributed.graph_store.dist_server.SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0[source]#

Interval (in seconds) to check exit status of server.

gigl.distributed.graph_store.dist_server.logger[source]#