gigl.distributed.graph_store.dist_server#

GiGL implementation of GLT DistServer.

The main change here is that we use GiGL's DistABLPSamplingProducer instead of GLT's DistMpSamplingProducer.

Based on alibaba/graphlearn-for-pytorch

Attributes#

R

SERVER_EXIT_STATUS_CHECK_INTERVAL

Interval (in seconds) to check exit status of server.

logger

Classes#

DistServer

A server that supports launching remote sampling workers for training clients.

Functions#

get_server()

Get the DistServer instance on the current process.

init_server(num_servers, server_rank, dataset, ...[, ...])

Initialize the current process as a server and establish connections with all other servers and clients.

wait_and_shutdown_server()

Block until all clients have been shut down, then shut down the server on the current process.

Module Contents#

class gigl.distributed.graph_store.dist_server.DistServer(dataset)[source]#

A server that supports launching remote sampling workers for training clients.

Note that this server is enabled only when the distribution mode is a server-client framework, and the graph and feature store will be partitioned and managed by all server nodes.

Parameters:

dataset (DistDataset) – The DistDataset object of a partition of graph data and feature data, along with distributed partition books.

create_sampling_ablp_producer(sampler_input, sampling_config, worker_options)[source]#

Create and initialize an instance of DistABLPSamplingProducer with a group of subprocesses for distributed sampling.

Parameters:
  • sampler_input (NodeSamplerInput or EdgeSamplerInput) – The input data for sampling.

  • sampling_config (SamplingConfig) – Configuration of sampling meta info.

  • worker_options (RemoteDistSamplingWorkerOptions) – Options for launching remote sampling workers by this server.

Returns:

A unique id of the created sampling producer on this server.

Return type:

int

create_sampling_producer(sampler_input, sampling_config, worker_options)[source]#

Create and initialize an instance of DistSamplingProducer with a group of subprocesses for distributed sampling.

Parameters:
  • sampler_input (NodeSamplerInput or EdgeSamplerInput) – The input data for sampling.

  • sampling_config (SamplingConfig) – Configuration of sampling meta info.

  • worker_options (RemoteDistSamplingWorkerOptions) – Options for launching remote sampling workers by this server.

Returns:

A unique id of the created sampling producer on this server.

Return type:

int

destroy_sampling_producer(producer_id)[source]#

Shutdown and destroy a sampling producer managed by this server with its producer id.

Parameters:

producer_id (int)

Return type:

None

exit()[source]#

Set the exit flag to True.

Return type:

bool

fetch_one_sampled_message(producer_id)[source]#

Fetch a sampled message from the buffer of a specific sampling producer with its producer id.

Parameters:

producer_id (int)

Return type:

tuple[Optional[bytes], bool]
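The producer lifecycle (create, fetch until the epoch ends, destroy) can be sketched with a toy in-memory stand-in. `ToyProducerRegistry` below is a hypothetical illustration of the `(payload, end_of_epoch)` protocol only, not part of the gigl API; the real server spawns sampling subprocesses and buffers their output.

```python
from typing import Optional

class ToyProducerRegistry:
    """Hypothetical stand-in mimicking the server's producer protocol:
    fetch_one_sampled_message returns (payload, end_of_epoch)."""

    def __init__(self) -> None:
        self._producers: dict[int, list[bytes]] = {}
        self._next_id = 0

    def create_sampling_producer(self, messages: list[bytes]) -> int:
        # The real server launches sampling subprocesses; here we just buffer
        # pre-made messages and hand back a unique producer id.
        producer_id = self._next_id
        self._next_id += 1
        self._producers[producer_id] = list(messages)
        return producer_id

    def fetch_one_sampled_message(self, producer_id: int) -> tuple[Optional[bytes], bool]:
        buffer = self._producers[producer_id]
        if not buffer:
            return None, True  # nothing left: no payload, epoch finished
        msg = buffer.pop(0)
        return msg, len(buffer) == 0

    def destroy_sampling_producer(self, producer_id: int) -> None:
        del self._producers[producer_id]

registry = ToyProducerRegistry()
pid = registry.create_sampling_producer([b"batch-0", b"batch-1"])
msgs = []
while True:
    payload, end_of_epoch = registry.fetch_one_sampled_message(pid)
    if payload is not None:
        msgs.append(payload)
    if end_of_epoch:
        break
registry.destroy_sampling_producer(pid)
```

A client driving the real server would loop the same way, stopping when the end-of-epoch flag is set and destroying the producer when done.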

get_ablp_input(split, rank=None, world_size=None, node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE)[source]#

Get the ABLP (Anchor Based Link Prediction) input for a specific rank in distributed processing.

Note: rank and world_size here are for the process group we’re fetching for, not the process group we’re fetching from. e.g. if our compute cluster is of world size 4, and we have 2 storage nodes, then the world size this gets called with is 4, not 2.

Parameters:
  • split (Union[Literal['train', 'val', 'test'], str]) – The split to get the training input for.

  • rank (Optional[int]) – The rank of the process requesting the training input. Defaults to None, in which case all nodes are returned. Must be provided if world_size is provided.

  • world_size (Optional[int]) – The total number of processes in the distributed setup. Defaults to None, in which case all nodes are returned. Must be provided if rank is provided.

  • node_type (gigl.src.common.types.graph_data.NodeType) – The type of nodes to retrieve. Defaults to the default homogeneous node type.

  • supervision_edge_type (gigl.src.common.types.graph_data.EdgeType) – The edge type to use for the supervision. Defaults to the default homogeneous edge type.

Returns:

A tuple containing the anchor nodes for the rank, the positive labels, and the negative labels. The positive labels are of shape [N, M], where N is the number of anchor nodes and M is the number of positive labels. The negative labels are of shape [N, M], where N is the number of anchor nodes and M is the number of negative labels. The negative labels may be None if no negative labels are available.

Raises:

ValueError – If the split is invalid.

Return type:

tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
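The shape contract of the returned tuple can be sketched with plain nested lists standing in for torch.Tensors; all ids below are made up for illustration.

```python
# Sketch of get_ablp_input's return shapes, using nested lists in place of
# torch.Tensors. N anchors, M_POS positives and M_NEG negatives per anchor.
N, M_POS, M_NEG = 3, 2, 4

anchor_nodes = [10, 11, 12]                       # shape [N]
positive_labels = [[20, 21], [22, 23], [24, 25]]  # shape [N, M_POS]
# Negatives may be None when no negative labels are available; here we
# fabricate a [N, M_NEG] grid.
negative_labels = [[i * 100 + j for j in range(M_NEG)] for i in range(N)]

assert len(anchor_nodes) == N
assert all(len(row) == M_POS for row in positive_labels)
assert all(len(row) == M_NEG for row in negative_labels)
```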

get_dataset_meta()[source]#

Get the meta info of the distributed dataset managed by the current server, including partition info and graph types.

Return type:

tuple[int, int, Optional[list[gigl.src.common.types.graph_data.NodeType]], Optional[list[gigl.src.common.types.graph_data.EdgeType]]]

get_edge_dir()[source]#

Get the edge direction from the dataset.

Returns:

The edge direction.

Return type:

Literal['in', 'out']

get_edge_feature_info()[source]#

Get edge feature information from the dataset.

Returns:

Edge feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping EdgeType to FeatureInfo for heterogeneous graphs

  • None if no edge features are available

get_edge_index(edge_type, layout)[source]#
Parameters:
Return type:

tuple[torch.Tensor, torch.Tensor]

get_edge_partition_book(edge_type)[source]#

Gets the partition book for the specified edge type.

Parameters:

edge_type (Optional[gigl.src.common.types.graph_data.EdgeType]) – The edge type to look up. Must be None for homogeneous datasets and non-None for heterogeneous ones.

Returns:

The partition book for the requested edge type, or None if no partition book is available.

Raises:

ValueError – If edge_type is mismatched with the dataset type.

Return type:

Optional[graphlearn_torch.partition.PartitionBook]

get_edge_size(edge_type, layout)[source]#
Parameters:
Return type:

tuple[int, int]

get_edge_types()[source]#

Get the edge types from the dataset.

Returns:

The edge types in the dataset, None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.EdgeType]]

get_node_feature(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

torch.Tensor

get_node_feature_info()[source]#

Get node feature information from the dataset.

Returns:

Node feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping NodeType to FeatureInfo for heterogeneous graphs

  • None if no node features are available

get_node_ids(rank=None, world_size=None, split=None, node_type=None)[source]#

Get the node ids from the dataset.

Parameters:
  • rank (Optional[int]) – The rank of the process requesting node ids. Must be provided if world_size is provided.

  • world_size (Optional[int]) – The total number of processes in the distributed setup. Must be provided if rank is provided.

  • split (Optional[Union[Literal['train', 'val', 'test'], str]]) – The split of the dataset to get node ids from. If provided, the dataset must have train_node_ids, val_node_ids, and test_node_ids properties.

  • node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – The type of nodes to get node ids for. Must be provided if the dataset is heterogeneous.

Returns:

The node ids.

Raises:

ValueError

  • If the rank and world_size are not provided together

  • If the split is invalid

  • If the node ids are not a torch.Tensor or a dict[NodeType, torch.Tensor]

  • If the node type is provided for a homogeneous dataset

  • If the node ids are not a dict[NodeType, torch.Tensor] when no node type is provided

Return type:

torch.Tensor

Examples

Suppose the dataset has 100 nodes total: train=[0..59], val=[60..79], test=[80..99].

Get all node ids (no split filtering):

>>> server.get_node_ids()
tensor([0, 1, 2, ..., 99])  # All 100 nodes

Get only training nodes:

>>> server.get_node_ids(split="train")
tensor([0, 1, 2, ..., 59])  # 60 training nodes

Shard all nodes across 4 processes (each gets ~25 nodes):

>>> server.get_node_ids(rank=0, world_size=4)
tensor([0, 1, 2, ..., 24])  # First 25 of all 100 nodes

Shard training nodes across 4 processes (each gets ~15 nodes):

>>> server.get_node_ids(rank=0, world_size=4, split="train")
tensor([0, 1, 2, ..., 14])  # First 15 of the 60 training nodes

Note: When split=None, all nodes are queryable. This means nodes from any split (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split.
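The even sharding in the examples above can be sketched in plain Python. `shard` is a hypothetical helper mirroring the contiguous-chunk behavior shown (the same convention as torch.tensor_split), not the library's implementation.

```python
def shard(ids: list[int], rank: int, world_size: int) -> list[int]:
    """Split ids into world_size contiguous chunks and return chunk `rank`.

    Earlier ranks receive one extra element when len(ids) is not evenly
    divisible by world_size.
    """
    n, remainder = divmod(len(ids), world_size)
    start = rank * n + min(rank, remainder)
    end = start + n + (1 if rank < remainder else 0)
    return ids[start:end]

all_ids = list(range(100))   # 100 nodes total, as in the example above
train_ids = list(range(60))  # the train split from the example above

assert shard(all_ids, rank=0, world_size=4) == list(range(25))
assert shard(train_ids, rank=0, world_size=4) == list(range(15))
```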

get_node_label(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

torch.Tensor

get_node_partition_book(node_type)[source]#

Gets the partition book for the specified node type.

Parameters:

node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – The node type to look up. Must be None for homogeneous datasets and non-None for heterogeneous ones.

Returns:

The partition book for the requested node type, or None if no partition book is available.

Raises:

ValueError – If node_type is mismatched with the dataset type.

Return type:

Optional[graphlearn_torch.partition.PartitionBook]

get_node_partition_id(node_type, index)[source]#
Parameters:
  • node_type (Optional[gigl.src.common.types.graph_data.NodeType])

  • index (torch.Tensor)

Return type:

Optional[torch.Tensor]

get_node_types()[source]#

Get the node types from the dataset.

Returns:

The node types in the dataset, None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.NodeType]]

get_tensor_size(node_type)[source]#
Parameters:

node_type (Optional[gigl.src.common.types.graph_data.NodeType])

Return type:

torch.Size

shutdown()[source]#
Return type:

None

start_new_epoch_sampling(producer_id, epoch)[source]#

Start sampling tasks for a new epoch on a specific sampling producer with its producer id.

Parameters:
  • producer_id (int)

  • epoch (int)

Return type:

None

wait_for_exit()[source]#

Block until the exit flag has been set to True.

Return type:

None

dataset[source]#
gigl.distributed.graph_store.dist_server.get_server()[source]#

Get the DistServer instance on the current process.

Return type:

DistServer

gigl.distributed.graph_store.dist_server.init_server(num_servers, server_rank, dataset, master_addr, master_port, num_clients=0, num_rpc_threads=16, request_timeout=180, server_group_name=None, is_dynamic=False)[source]#

Initialize the current process as a server and establish connections with all other servers and clients. Note that this method should be called only in the server-client distribution mode.

Parameters:
  • num_servers (int) – Number of processes participating in the server group.

  • server_rank (int) – Rank of the current process within the server group (a number between 0 and num_servers-1).

  • dataset (DistDataset) – The DistDataset object of a partition of graph data and feature data, along with distributed partition book info.

  • master_addr (str) – The master TCP address for RPC connections between all servers and clients; this value must be the same on every server and client.

  • master_port (int) – The master TCP port for RPC connections between all servers and clients; this value must be the same on every server and client.

  • num_clients (int) – Number of processes participating in the client group. If is_dynamic is True, this parameter is ignored.

  • num_rpc_threads (int) – The number of RPC worker threads used by the current server to respond to remote requests. (Default: 16).

  • request_timeout (int) – The max timeout in seconds for remote requests; an exception is raised if a request exceeds it. (Default: 180).

  • server_group_name (str) – A unique name of the server group that the current process belongs to. If set to None, a default name will be used. (Default: None).

  • is_dynamic (bool) – Whether the world size is dynamic. (Default: False).

Return type:

None
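A sketch of how a server process might assemble these arguments. The hostname, port, and counts are hypothetical placeholders, and the actual init_server / wait_and_shutdown_server calls are left as comments since they require a live cluster and a DistDataset.

```python
def make_server_kwargs(server_rank: int, num_servers: int) -> dict:
    """Collect per-process arguments for init_server (dataset omitted).

    master_addr and master_port must be identical on every server and client.
    """
    if not 0 <= server_rank < num_servers:
        raise ValueError(f"server_rank must be in [0, {num_servers - 1}], got {server_rank}")
    return dict(
        num_servers=num_servers,
        server_rank=server_rank,
        master_addr="storage-0.internal",  # hypothetical rendezvous host
        master_port=11234,                 # hypothetical rendezvous port
        num_clients=4,                     # ignored when is_dynamic=True
        num_rpc_threads=16,
        request_timeout=180,
    )

# On each of two storage nodes, the flow would then be roughly:
#   init_server(dataset=my_dist_dataset, **make_server_kwargs(rank, 2))
#   wait_and_shutdown_server()  # blocks until every client disconnects

kwargs = make_server_kwargs(server_rank=0, num_servers=2)
```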

gigl.distributed.graph_store.dist_server.wait_and_shutdown_server()[source]#

Block until all clients have been shut down, then shut down the server on the current process and destroy all RPC connections.

Return type:

None

gigl.distributed.graph_store.dist_server.R[source]#
gigl.distributed.graph_store.dist_server.SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0[source]#

Interval (in seconds) to check exit status of server.

gigl.distributed.graph_store.dist_server.logger[source]#