gigl.distributed.graph_store.storage_utils#

Utils for operating on a dataset remotely.

These are intended to be used in the context of a server-client architecture, and with graphlearn_torch.distributed.request_server.

register_dataset must be called once per process on the server.

A client can then call these functions remotely, for example:

>>> edge_feature_info = graphlearn_torch.distributed.request_server(
...     server_rank,
...     gigl.distributed.graph_store.storage_utils.get_edge_feature_info,
... )

NOTE: Ideally these would be exposed via DistServer [1] so we could call them directly. TODO(kmonte): If we ever fork GLT, we should look into expanding DistServer instead.

[1]: https://github.com/alibaba/graphlearn-for-pytorch

Attributes#

logger

Functions#

get_ablp_input(split[, rank, world_size, node_type, ...])
    Get the ABLP (Anchor Based Link Prediction) input for a specific rank in distributed processing.

get_edge_dir()
    Get the edge direction from the registered dataset.

get_edge_feature_info()
    Get edge feature information from the registered dataset.

get_edge_types()
    Get the edge types from the registered dataset.

get_node_feature_info()
    Get node feature information from the registered dataset.

get_node_ids([rank, world_size, split, node_type])
    Get the node ids from the registered dataset.

register_dataset(dataset)
    Register a dataset for remote access.

Module Contents#

gigl.distributed.graph_store.storage_utils.get_ablp_input(split, rank=None, world_size=None, node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE)[source]#

Get the ABLP (Anchor Based Link Prediction) input for a specific rank in distributed processing.

Note: rank and world_size here refer to the process group we’re fetching for, not the process group we’re fetching from. For example, if our compute cluster has a world size of 4 and there are 2 storage nodes, then this is called with a world size of 4, not 2.

Parameters:
  • split (Union[Literal['train', 'val', 'test'], str]) – The split to get the training input for.

  • rank (Optional[int]) – The rank of the process requesting the training input. Defaults to None, in which case all nodes are returned. Must be provided if world_size is provided.

  • world_size (Optional[int]) – The total number of processes in the distributed setup. Defaults to None, in which case all nodes are returned. Must be provided if rank is provided.

  • node_type (gigl.src.common.types.graph_data.NodeType) – The type of nodes to retrieve. Defaults to the default homogeneous node type.

  • supervision_edge_type (gigl.src.common.types.graph_data.EdgeType) – The edge type to use for the supervision. Defaults to the default homogeneous edge type.

Returns:

A tuple containing the anchor nodes for the rank, the positive labels, and the negative labels. The positive labels have shape [N, M], where N is the number of anchor nodes and M is the number of positive labels per anchor; the negative labels likewise have shape [N, M], with M the number of negative labels per anchor. The negative labels may be None if no negative labels are available.

Raises:

ValueError – If no dataset has been registered or if the split is invalid.

Return type:

tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
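
For example, a client might fetch its shard of the training input remotely. A minimal sketch, assuming a dataset has been registered on the server, that server_rank identifies a storage node, and that request_server forwards extra arguments through to the target function:

>>> anchors, positive, negative = graphlearn_torch.distributed.request_server(
...     server_rank,
...     gigl.distributed.graph_store.storage_utils.get_ablp_input,
...     "train",  # split
...     rank=0,  # rank within the *compute* cluster
...     world_size=4,  # world size of the compute cluster, not the storage cluster
... )
>>> anchors.shape[0] == positive.shape[0]  # one row of positive labels per anchor
True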

gigl.distributed.graph_store.storage_utils.get_edge_dir()[source]#

Get the edge direction from the registered dataset.

Returns:

The edge direction.

Return type:

Literal['in', 'out']
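
A minimal client-side sketch (assuming a registered dataset and a valid server_rank), e.g. to decide the sampling direction:

>>> edge_dir = graphlearn_torch.distributed.request_server(
...     server_rank,
...     gigl.distributed.graph_store.storage_utils.get_edge_dir,
... )
>>> edge_dir in ("in", "out")
True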

gigl.distributed.graph_store.storage_utils.get_edge_feature_info()[source]#

Get edge feature information from the registered dataset.

Returns:

Edge feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping EdgeType to FeatureInfo for heterogeneous graphs

  • None if no edge features are available

Return type:

Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]

Raises:

ValueError – If no dataset has been registered.
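
Since the return value is a union, callers typically branch on its type. A sketch, assuming a registered dataset and a valid server_rank:

>>> info = graphlearn_torch.distributed.request_server(
...     server_rank,
...     gigl.distributed.graph_store.storage_utils.get_edge_feature_info,
... )
>>> if info is None:
...     print("no edge features")
... elif isinstance(info, dict):  # heterogeneous: EdgeType -> FeatureInfo
...     print(sorted(info.keys()))
... else:  # homogeneous: a single FeatureInfo
...     print(info)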

gigl.distributed.graph_store.storage_utils.get_edge_types()[source]#

Get the edge types from the registered dataset.

Returns:

The edge types in the dataset, or None if the dataset is homogeneous.

Return type:

Optional[list[gigl.src.common.types.graph_data.EdgeType]]
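
The None return doubles as a homogeneity check. A sketch, under the same assumptions as the examples above:

>>> edge_types = graphlearn_torch.distributed.request_server(
...     server_rank,
...     gigl.distributed.graph_store.storage_utils.get_edge_types,
... )
>>> is_heterogeneous = edge_types is not None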

gigl.distributed.graph_store.storage_utils.get_node_feature_info()[source]#

Get node feature information from the registered dataset.

Returns:

Node feature information, which can be:

  • A single FeatureInfo object for homogeneous graphs

  • A dict mapping NodeType to FeatureInfo for heterogeneous graphs

  • None if no node features are available

Return type:

Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]

Raises:

ValueError – If no dataset has been registered.
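
For the heterogeneous case, the returned dict maps each NodeType to its FeatureInfo. A sketch, assuming a registered heterogeneous dataset and a valid server_rank:

>>> info = graphlearn_torch.distributed.request_server(
...     server_rank,
...     gigl.distributed.graph_store.storage_utils.get_node_feature_info,
... )
>>> if isinstance(info, dict):
...     for node_type, feature_info in info.items():
...         print(node_type, feature_info)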

gigl.distributed.graph_store.storage_utils.get_node_ids(rank=None, world_size=None, split=None, node_type=None)[source]#

Get the node ids from the registered dataset.

Parameters:
  • rank (Optional[int]) – The rank of the process requesting node ids. Must be provided if world_size is provided.

  • world_size (Optional[int]) – The total number of processes in the distributed setup. Must be provided if rank is provided.

  • split (Optional[Literal["train", "val", "test"]]) – The split of the dataset to get node ids from. If provided, the dataset must have train_node_ids, val_node_ids, and test_node_ids properties.

  • node_type (Optional[NodeType]) – The type of nodes to get node ids for. Must be provided if the dataset is heterogeneous.

Returns:

The node ids.

Raises:

ValueError

  • If no dataset has been registered

  • If the rank and world_size are not provided together

  • If the split is invalid

  • If the node ids are not a torch.Tensor or a dict[NodeType, torch.Tensor]

  • If the node type is provided for a homogeneous dataset

  • If the node ids are not a dict[NodeType, torch.Tensor] when no node type is provided

Return type:

torch.Tensor

Examples

Suppose the dataset has 100 nodes total: train=[0..59], val=[60..79], test=[80..99].

Get all node ids (no split filtering):

>>> get_node_ids()
tensor([0, 1, 2, ..., 99])  # All 100 nodes

Get only training nodes:

>>> get_node_ids(split="train")
tensor([0, 1, 2, ..., 59])  # 60 training nodes

Shard all nodes across 4 processes (each gets ~25 nodes):

>>> get_node_ids(rank=0, world_size=4)
tensor([0, 1, 2, ..., 24])  # First 25 of all 100 nodes

Shard training nodes across 4 processes (each gets ~15 nodes):

>>> get_node_ids(rank=0, world_size=4, split="train")
tensor([0, 1, 2, ..., 14])  # First 15 of the 60 training nodes

Note: When split=None, all nodes are queryable. This means nodes from any split (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split.
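
In the server-client setting, the same call is issued remotely. A sketch, assuming request_server forwards keyword arguments and that compute_rank and compute_world_size (placeholder names) describe the requesting compute cluster:

>>> node_ids = graphlearn_torch.distributed.request_server(
...     server_rank,
...     gigl.distributed.graph_store.storage_utils.get_node_ids,
...     rank=compute_rank,
...     world_size=compute_world_size,
...     split="train",
... )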

gigl.distributed.graph_store.storage_utils.register_dataset(dataset)[source]#

Register a dataset for remote access.

This function must be called once per process on the server before any remote dataset operations can be performed.

Parameters:

dataset (gigl.distributed.dist_dataset.DistDataset) – The distributed dataset to register.

Raises:

ValueError – If a dataset has already been registered.

Return type:

None
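
A minimal server-side sketch, assuming dataset is an already-built gigl.distributed.dist_dataset.DistDataset:

>>> from gigl.distributed.graph_store import storage_utils
>>> storage_utils.register_dataset(dataset)
>>> # Calling register_dataset again in this process would raise ValueError.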

gigl.distributed.graph_store.storage_utils.logger[source]#