gigl.distributed.graph_store#

Public API for GiGL’s graph-store deployment mode.

Graph-store mode separates storage and compute clusters: storage nodes build and serve a partitioned dataset, while compute nodes connect over RPC via a RemoteDistDataset.

This module is the stable import surface for that workflow; helpers, RPC utilities, and server lifecycle internals remain in private submodules.

Classes#

GraphStoreInfo

Information about a graph store cluster.

RemoteDistDataset

Represents a dataset that is stored on a different storage cluster.

Functions#

build_storage_dataset(task_config_uri, ...[, ...])

Build a DistDataset for a storage node from a task config.

get_graph_store_info()

Get the information about the graph store cluster.

init_compute_process(local_rank, cluster_info[, ...])

Initializes distributed setup for a compute node in a Graph Store cluster.

run_storage_server(storage_rank, cluster_info, ...[, ...])

Spawn sequential storage-server sessions as subprocesses.

shutdown_compute_process()

Shut down the compute side of a Graph Store cluster.

Package Contents#

class gigl.distributed.graph_store.GraphStoreInfo[source]#

Information about a graph store cluster.

cluster_master_ip: str#
cluster_master_port: int#
compute_cluster_master_ip: str#
compute_cluster_master_port: int#
property compute_cluster_world_size: int#
Return type:

int

property compute_node_rank: int#

Get the rank of the compute node in the compute cluster.

Raises:

ValueError – If the node is not in the compute cluster.

Return type:

int

property num_cluster_nodes: int#
Return type:

int

num_compute_nodes: int#
num_processes_per_compute: int#
num_storage_nodes: int#
readiness_uri: gigl.common.Uri#
rpc_master_port: int#
rpc_wait_port: int#
storage_cluster_master_ip: str#
storage_cluster_master_port: int#
property storage_node_rank: int#

Get the rank of the storage node in the storage cluster.

Raises:

ValueError – If the node is not in the storage cluster.

Return type:

int
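
The derived size properties above are not defined in this reference. The following is a minimal sketch of how they plausibly relate to the stored fields. It assumes that every node is either a storage node or a compute node and that each compute process gets its own rank. `ClusterShape` is a hypothetical stand-in, not the real `GraphStoreInfo`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClusterShape:
    """Hypothetical stand-in for the size-related fields of GraphStoreInfo."""

    num_storage_nodes: int
    num_compute_nodes: int
    num_processes_per_compute: int

    @property
    def num_cluster_nodes(self) -> int:
        # Assumption: the cluster consists only of storage and compute nodes.
        return self.num_storage_nodes + self.num_compute_nodes

    @property
    def compute_cluster_world_size(self) -> int:
        # Assumption: one rank per process on every compute node.
        return self.num_compute_nodes * self.num_processes_per_compute


shape = ClusterShape(
    num_storage_nodes=2, num_compute_nodes=4, num_processes_per_compute=8
)
print(shape.num_cluster_nodes)           # 6
print(shape.compute_cluster_world_size)  # 32
```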

class gigl.distributed.graph_store.RemoteDistDataset(cluster_info, local_rank)[source]#

Represents a dataset that is stored on a different storage cluster.

This class must be used on the compute (client) side of the GiGL graph-store distributed setup.

Parameters:
  • cluster_info (GraphStoreInfo) – The cluster information.

  • local_rank (int) – The local rank of the process on the compute node.

fetch_ablp_input(split, rank=None, world_size=None, anchor_node_type=None, supervision_edge_type=None)[source]#

Fetch ABLP (Anchor Based Link Prediction) input from the storage nodes.

The returned dict maps storage rank to an ABLPInputNodes dataclass for that storage node. If both rank and world_size are provided, the input is sharded across the compute nodes using contiguous server assignments. If both are None, the input is returned unsharded for all storage nodes.

The ABLPInputNodes dataclass carries explicit node type information and keys the label tensors by their label EdgeType, making it unambiguous which node types the positive/negative labels correspond to.

Parameters:
  • split (Literal['train', 'val', 'test']) – The split to get the input for.

  • rank (Optional[int]) – The compute rank requesting data. When None (together with world_size), all data is returned unsharded from all storage nodes.

  • world_size (Optional[int]) – The total number of compute processes. When None (together with rank), all data is returned unsharded from all storage nodes.

  • anchor_node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – The type of the anchor nodes to retrieve. Must be provided for heterogeneous graphs. Must be None for labeled homogeneous graphs.

  • supervision_edge_type (Optional[gigl.src.common.types.graph_data.EdgeType]) – The edge type for supervision. Must be provided for heterogeneous graphs. Must be None for labeled homogeneous graphs.

Returns:

A dict mapping storage rank to an ABLPInputNodes containing:

  • anchor_node_type: The node type of the anchor nodes, or DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous graphs.

  • anchor_nodes: 1D tensor of anchor node IDs for the split.

  • positive_labels: Dict mapping positive label EdgeType to a 2D tensor [N, M].

  • negative_labels: Optional dict mapping negative label EdgeType to a 2D tensor [N, M].

Raises:

ValueError – If only one of rank or world_size is provided.

Example

Suppose we have 2 storage nodes and 2 compute nodes. Storage rank 0 has anchor nodes [0, 1, 2] (train), storage rank 1 has anchor nodes [3, 4, 5] (train), with positive/negative labels for link prediction.

Shard training ABLP input across 2 compute nodes (contiguous — each rank gets entire servers):

>>> dataset.fetch_ablp_input(split="train", rank=0, world_size=2)
{
    0: ABLPInputNodes(
        anchor_nodes=tensor([0, 1, 2]),
        labels={...},
    ),
    1: ABLPInputNodes(
        anchor_nodes=tensor([]),
        labels={...},
    ),
}
>>> dataset.fetch_ablp_input(split="train", rank=1, world_size=2)
{
    0: ABLPInputNodes(
        anchor_nodes=tensor([]),
        labels={...},
    ),
    1: ABLPInputNodes(
        anchor_nodes=tensor([3, 4, 5]),
        labels={...},
    ),
}

With 3 storage nodes and 2 compute nodes, server 1 is fractionally split. Storage rank 0 has anchors [0, 1], rank 1 has [2, 3], rank 2 has [4, 5]:

>>> dataset.fetch_ablp_input(split="train", rank=0, world_size=2)
{
    0: ABLPInputNodes(
        anchor_nodes=tensor([0, 1]),
        labels={...},
    ),
    1: ABLPInputNodes(
        anchor_nodes=tensor([2]),    # First half of storage 1
        labels={...},
    ),
    2: ABLPInputNodes(
        anchor_nodes=tensor([]),     # Nothing from storage 2
        labels={...},
    ),
}
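
The both-or-neither contract on rank and world_size can be sketched as a small validation helper. This is an illustration of the documented behaviour only. `resolve_shard` is a hypothetical name, and the real method may validate its arguments differently.

```python
from typing import Optional


def resolve_shard(
    rank: Optional[int], world_size: Optional[int]
) -> Optional[tuple[int, int]]:
    """Return (rank, world_size) for a sharded request, or None for unsharded."""
    # Exactly one of the two being None is the error case from the docs.
    if (rank is None) != (world_size is None):
        raise ValueError(
            "rank and world_size must be provided together or both left as None"
        )
    if rank is None:
        return None  # Unsharded: every storage node returns all of its input.
    return rank, world_size


print(resolve_shard(None, None))  # None
print(resolve_shard(0, 2))        # (0, 2)
```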
fetch_edge_dir()[source]#

Fetch the edge direction from the registered dataset.

Returns:

The edge direction.

Return type:

Union[str, Literal['in', 'out']]

fetch_edge_feature_info()[source]#

Fetch edge feature information from the registered dataset.

Returns:

Edge feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping EdgeType to FeatureInfo for heterogeneous graphs

  • None if no edge features are available

fetch_edge_partition_book(edge_type=None)[source]#

Fetches the partition book for the specified edge type.

Parameters:

edge_type (Optional[gigl.src.common.types.graph_data.EdgeType]) – The edge type to look up. Must be None for homogeneous datasets and non-None for heterogeneous ones.

Returns:

The partition book for the requested edge type, or None if no partition book is available.

Return type:

Optional[graphlearn_torch.partition.PartitionBook]

fetch_edge_types()[source]#

Fetch the edge types from the registered dataset.

Returns:

The edge types in the dataset, None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.EdgeType]]

fetch_edge_weights_registered()[source]#

Fetch whether edge weights were registered in the remote dataset.

Returns:

True if edge weights were registered via DistPartitioner.register_edge_weights().

Return type:

bool

fetch_free_ports_on_storage_cluster(num_ports)[source]#

Get free ports from the storage master node.

This must be called with a torch.distributed process group initialized for the entire training cluster.

All compute ranks will receive the same free ports.

Parameters:

num_ports (int) – Number of free ports to get.

Returns:

A list of free port numbers on the storage master node.

Return type:

list[int]
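
For intuition, free ports are commonly discovered by binding sockets to port 0 and reading back the port the operating system assigned. The sketch below shows that general technique with the standard library. It is not taken from the GiGL implementation, and `find_free_ports` is a hypothetical helper.

```python
import socket


def find_free_ports(num_ports: int) -> list[int]:
    """Ask the OS for `num_ports` distinct ports that are currently free."""
    sockets: list[socket.socket] = []
    try:
        for _ in range(num_ports):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Port 0 tells the OS to pick any unused port.
            sock.bind(("", 0))
            sockets.append(sock)
        # All sockets are still open here, so the ports are guaranteed distinct.
        return [sock.getsockname()[1] for sock in sockets]
    finally:
        for sock in sockets:
            sock.close()


ports = find_free_ports(3)
print(ports)
```

A port found this way can be claimed by another process once the socket is closed, so callers should bind to it promptly.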

fetch_node_feature_info()[source]#

Fetch node feature information from the registered dataset.

Returns:

Node feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping NodeType to FeatureInfo for heterogeneous graphs

  • None if no node features are available

fetch_node_ids(rank=None, world_size=None, split=None, node_type=None)[source]#

Fetch node ids from the storage nodes for the current compute node (machine).

The returned dict maps storage rank to the node ids stored on that storage node, filtered and sharded according to the provided arguments.

Storage servers are assigned to compute nodes in contiguous blocks. Each compute node fetches all data from its assigned server(s) and receives empty tensors for unassigned ones. When the servers do not divide evenly among the compute nodes, a server on the boundary is split fractionally between adjacent ranks. When both rank and world_size are None, all data is returned unsharded from every storage server.

Parameters:
  • rank (Optional[int]) – The compute rank requesting data. When None (together with world_size), all data is returned unsharded from all storage nodes.

  • world_size (Optional[int]) – The total number of compute processes. When None (together with rank), all data is returned unsharded from all storage nodes.

  • split (Optional[Literal['train', 'val', 'test']]) – The split of the dataset to get node ids from. If provided, the dataset must have train_node_ids, val_node_ids, and test_node_ids properties.

  • node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – The type of nodes to get. Must be provided for heterogeneous datasets. Must be None for labeled homogeneous graphs.

Raises:

ValueError – If only one of rank or world_size is provided.

Returns:

A dict mapping storage rank to node ids.

Return type:

dict[int, torch.Tensor]

Example

Suppose we have 2 storage nodes and 2 compute nodes, with 16 total nodes. Nodes are partitioned across storage nodes, with splits defined as:

Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7]
    train=[0, 1, 2, 3], val=[4, 5], test=[6, 7]
Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15]
    train=[8, 9, 10, 11], val=[12, 13], test=[14, 15]

Get all nodes (no split filtering, no sharding):

>>> dataset.fetch_node_ids()
{
    0: tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    1: tensor([8, 9, 10, 11, 12, 13, 14, 15]),
}

Shard training nodes across 2 compute nodes (contiguous — each rank gets entire servers):

>>> dataset.fetch_node_ids(rank=0, world_size=2, split="train")
{
    0: tensor([0, 1, 2, 3]),  # All training nodes from storage 0
    1: tensor([]),             # Nothing from storage 1
}
>>> dataset.fetch_node_ids(rank=1, world_size=2, split="train")
{
    0: tensor([]),             # Nothing from storage 0
    1: tensor([8, 9, 10, 11]), # All training nodes from storage 1
}

With 3 storage nodes and 2 compute nodes, server 1 is fractionally split:

>>> dataset.fetch_node_ids(rank=0, world_size=2, split="train")
{
    0: tensor([0, 1, 2, 3]),  # All of storage 0
    1: tensor([8, 9]),         # First half of storage 1
    2: tensor([]),             # Nothing from storage 2
}
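
The contiguous assignment in the examples above can be reproduced with a small pure-Python model. It is a sketch under stated assumptions: servers are numbered 0 to S-1, rank r owns the interval [r*S/W, (r+1)*S/W) on the server axis, and a partially owned server contributes the matching fraction of its ids. `shard_server_ids` is a hypothetical helper, plain lists stand in for tensors, and the ids for storage rank 2 are made up for illustration.

```python
from fractions import Fraction


def shard_server_ids(
    ids_by_server: dict[int, list[int]], rank: int, world_size: int
) -> dict[int, list[int]]:
    """Return, per server, the ids that fall to compute rank `rank`."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    num_servers = len(ids_by_server)
    # The slice of the server axis owned by this rank, kept exact with Fraction.
    start = Fraction(rank * num_servers, world_size)
    end = Fraction((rank + 1) * num_servers, world_size)
    result: dict[int, list[int]] = {}
    for server, ids in sorted(ids_by_server.items()):
        # Overlap of [start, end) with this server's unit interval,
        # expressed as a fraction of the server in [0, 1].
        lo = max(start, Fraction(server)) - server
        hi = min(end, Fraction(server + 1)) - server
        if hi <= lo:
            result[server] = []  # Server not assigned to this rank.
        else:
            result[server] = ids[int(lo * len(ids)):int(hi * len(ids))]
    return result


train_ids = {0: [0, 1, 2, 3], 1: [8, 9, 10, 11], 2: [16, 17, 18, 19]}
print(shard_server_ids(train_ids, rank=0, world_size=2))
# {0: [0, 1, 2, 3], 1: [8, 9], 2: []}
print(shard_server_ids(train_ids, rank=1, world_size=2))
# {0: [], 1: [10, 11], 2: [16, 17, 18, 19]}
```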

Note

When split=None, all nodes are queryable. This means nodes from any split (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split.

fetch_node_partition_book(node_type=None)[source]#

Fetches the partition book for the specified node type.

Parameters:

node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – The node type to look up. Must be None for homogeneous datasets and non-None for heterogeneous ones.

Returns:

The partition book for the requested node type, or None if no partition book is available.

Return type:

Optional[graphlearn_torch.partition.PartitionBook]

fetch_node_types()[source]#

Fetch the node types from the registered dataset.

Returns:

The node types in the dataset, None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.NodeType]]

property cluster_info: gigl.env.distributed.GraphStoreInfo#
Return type:

gigl.env.distributed.GraphStoreInfo

gigl.distributed.graph_store.build_storage_dataset(task_config_uri, sample_edge_direction, tf_record_uri_pattern='.*-of-.*\\.tfrecord(\\.gz)?$', splitter=None, should_load_tensors_in_parallel=True, ssl_positive_label_percentage=None, max_labels_per_anchor_node=None)[source]#

Build a DistDataset for a storage node from a task config.

Loads the GBML config from task_config_uri, translates the protobuf metadata into SerializedGraphMetadata, and delegates to build_dataset() with DistRangePartitioner.

This should be called once per storage node (machine). A torch.distributed process group must already be initialized among all storage nodes before calling this function so that the dataset can be partitioned correctly.

Parameters:
  • task_config_uri (gigl.common.Uri) – URI pointing to a frozen GbmlConfig protobuf.

  • sample_edge_direction (Literal['in', 'out']) – Direction for edge sampling ("in" or "out").

  • tf_record_uri_pattern (str) – Regex pattern to match TFRecord file URIs.

  • splitter (Optional[Union[gigl.utils.data_splitters.DistNodeAnchorLinkSplitter, gigl.utils.data_splitters.DistNodeSplitter]]) – Optional splitter for node-anchor-link or node splitting. If None, the dataset will not be split.

  • should_load_tensors_in_parallel (bool) – Whether to load TFRecord tensors in parallel.

  • ssl_positive_label_percentage (Optional[float]) – Fraction of edges to select as self-supervised positive labels. Must be None when supervised edge labels are already provided. For example, 0.1 selects 10% of edges.

  • max_labels_per_anchor_node (Optional[int]) – Optional cap for how many labels to materialize per anchor node when the storage server serves ABLP input.

Returns:

A partitioned DistDataset ready to be served.

Return type:

gigl.distributed.dist_dataset.DistDataset
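
To see what the default tf_record_uri_pattern accepts, the regex can be tried against a few candidate URIs. The bucket and file names below are made up, and using `re.match` is an assumption about how the pattern is applied.

```python
import re

# The default shown in the signature, written as a raw string.
DEFAULT_PATTERN = r".*-of-.*\.tfrecord(\.gz)?$"

candidates = [
    "gs://example-bucket/nodes/part-00000-of-00004.tfrecord",
    "gs://example-bucket/nodes/part-00001-of-00004.tfrecord.gz",
    "gs://example-bucket/nodes/part-00002-of-00004.tfrecord.tmp",
    "gs://example-bucket/nodes/data.tfrecord",
    "gs://example-bucket/nodes/schema.pbtxt",
]

# Keep only sharded TFRecord files, optionally gzip-compressed.
matched = [uri for uri in candidates if re.match(DEFAULT_PATTERN, uri)]
for uri in matched:
    print(uri)
```

Only the first two URIs match: the pattern requires a `-of-` shard marker and a name ending in `.tfrecord` or `.tfrecord.gz`.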

gigl.distributed.graph_store.get_graph_store_info()[source]#

Get the information about the graph store cluster.

Must be called with a torch.distributed process group initialized for the entire training cluster; that is, the process group must include both the compute and storage nodes.

This function should only be called on clusters that are set up by GiGL, for example when GiGLResourceConfig.trainer_resource_config.vertex_ai_graph_store_trainer_config is set.

Returns:

The information about the graph store cluster.

Return type:

GraphStoreInfo

Raises:
  • ValueError – If a torch distributed environment is not initialized.

  • ValueError – If not running in a supported environment.

gigl.distributed.graph_store.init_compute_process(local_rank, cluster_info, compute_world_backend=None, rpc_timeout=300)[source]#

Initializes distributed setup for a compute node in a Graph Store cluster.

Should be called once per compute process, i.e. once per process on each compute node, for a total of cluster_info.compute_cluster_world_size calls across the cluster.

Parameters:
  • local_rank (int) – The local (process) rank on the compute node.

  • cluster_info (GraphStoreInfo) – The cluster information.

  • compute_world_backend (Optional[str]) – The backend for the compute Torch Distributed process group.

  • rpc_timeout (int) – The max timeout in seconds for remote RPC requests.

Raises:

ValueError – If the process group is already initialized.

Return type:

None

gigl.distributed.graph_store.run_storage_server(storage_rank, cluster_info, dataset, num_server_sessions, timeout_seconds=None, num_rpc_threads=16, rpc_timeout=None)[source]#

Spawn sequential storage-server sessions as subprocesses.

Each server session requires its own spawned process because you cannot re-connect to the same GLT server process after it has been joined. This function loops over num_server_sessions, spawning _run_storage_server_session() as a subprocess each time and joining it before starting the next.

Parameters:
  • storage_rank (int) – Rank of this storage node in the storage cluster.

  • cluster_info (gigl.env.distributed.GraphStoreInfo) – Cluster topology information.

  • dataset (gigl.distributed.dist_dataset.DistDataset) – The DistDataset to serve.

  • num_server_sessions (int) – Number of sequential server sessions to run (typically one per inference node type).

  • timeout_seconds (Optional[float]) – Timeout for joining each server subprocess. None waits indefinitely.

  • num_rpc_threads (int) – The number of RPC threads to use for the server, which is the maximum number of concurrent RPC requests the server can handle. Set it to the maximum number of concurrent RPCs a server must handle; in practice, the compute world size is an upper bound.

  • rpc_timeout (Optional[int]) – The max timeout in seconds for remote RPC requests. If None, uses the init_server default of 180 seconds. If long-running RPCs (e.g. producer creation) time out, increase this value.

Return type:

None
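
The spawn-then-join loop described above can be illustrated with the standard library. This sketch swaps in `subprocess.run` with a trivial child command in place of the real `_run_storage_server_session()` and whatever process-spawning mechanism GiGL actually uses. `run_sessions` is a hypothetical name.

```python
import subprocess
import sys
from typing import Optional


def run_sessions(
    num_server_sessions: int, timeout_seconds: Optional[float] = None
) -> list[int]:
    """Run sessions one after another, each in a fresh Python process."""
    exit_codes: list[int] = []
    for _ in range(num_server_sessions):
        # Placeholder workload standing in for one server session.
        command = [sys.executable, "-c", "pass"]
        # subprocess.run blocks until the child exits, which plays the role
        # of joining the session before the next one starts. With a timeout
        # set, a child that runs too long raises subprocess.TimeoutExpired.
        completed = subprocess.run(
            command, timeout=timeout_seconds, capture_output=True, check=False
        )
        exit_codes.append(completed.returncode)
    return exit_codes


print(run_sessions(3, timeout_seconds=60))
```

Because every session gets a new process, no state from a finished session carries over into the next one.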

gigl.distributed.graph_store.shutdown_compute_process()[source]#

Shut down the compute side of a Graph Store cluster.

Step 2 of the three-phase teardown described in gigl.distributed.graph_store.dist_server — call this after every DistLoader.shutdown() on this rank has returned, so all server-side channels are already destroyed.

Calls glt.distributed.shutdown_client and torch.distributed.destroy_process_group exactly once per compute process.

Should be called once per compute process, i.e. once per process on each compute node, for a total of cluster_info.compute_cluster_world_size calls across the cluster.

Return type:

None
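
The ordering requirement (shut down every loader first, then the compute process exactly once) can be captured in a small helper. This is a sketch with stand-in objects: `teardown_compute_rank` and `FakeLoader` are hypothetical, and in real code the second argument would be `shutdown_compute_process`.

```python
from typing import Callable, Iterable, Protocol


class SupportsShutdown(Protocol):
    def shutdown(self) -> None: ...


def teardown_compute_rank(
    loaders: Iterable[SupportsShutdown], shutdown_process: Callable[[], None]
) -> None:
    """Shut down all loaders on this rank, then the compute process once."""
    for loader in loaders:
        loader.shutdown()  # Must return before the process-level shutdown.
    shutdown_process()


# Stand-ins that record the order in which shutdowns happen.
events: list[str] = []


class FakeLoader:
    def __init__(self, name: str) -> None:
        self.name = name

    def shutdown(self) -> None:
        events.append(f"loader:{self.name}")


teardown_compute_rank(
    [FakeLoader("train"), FakeLoader("val")],
    shutdown_process=lambda: events.append("process"),
)
print(events)  # ['loader:train', 'loader:val', 'process']
```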