Source code for gigl.distributed.graph_store.messages
"""RPC request messages for graph-store fetch operations."""
from dataclasses import dataclass
from typing import Literal, Optional, Union
from gigl.src.common.types.graph_data import EdgeType, NodeType
@dataclass(frozen=True)
class FetchNodesRequest:
    """Request for fetching node IDs from a storage server.

    Sharding is opt-in: leave both ``rank`` and ``world_size`` unset to make
    an unsharded request, or set both to describe the requesting process.
    Construction does not validate anything; call :meth:`validate` to check
    that the sharding arguments are consistent.

    Args:
        rank: The rank of the process requesting node ids.
            Must be provided together with ``world_size`` and satisfy
            ``0 <= rank < world_size``.
        world_size: The total number of processes in the distributed setup.
            Must be provided together with ``rank`` and be positive.
        split: The split of the dataset to get node ids from.
        node_type: The type of nodes to get node ids for.

    Examples:
        Fetch all nodes without sharding:

        >>> FetchNodesRequest()

        Fetch training nodes for rank 0 of 4:

        >>> FetchNodesRequest(rank=0, world_size=4, split="train")

        Fetch nodes of a specific type:

        >>> FetchNodesRequest(node_type="user")
    """

    rank: Optional[int] = None
    world_size: Optional[int] = None
    split: Optional[Union[Literal["train", "val", "test"], str]] = None
    node_type: Optional[NodeType] = None

    def validate(self) -> None:
        """Validate that the request has a consistent rank/world_size pair.

        Raises:
            ValueError: If only one of ``rank`` or ``world_size`` is provided,
                if ``world_size`` is not positive, or if ``rank`` falls
                outside ``[0, world_size)``.
        """
        if (self.rank is None) ^ (self.world_size is None):
            raise ValueError(
                "rank and world_size must be provided together. "
                f"Received rank={self.rank}, world_size={self.world_size}"
            )
        if self.rank is None or self.world_size is None:
            # Both are unset (the mixed case was rejected above): this is an
            # unsharded request, so there are no bounds to check.
            return
        if self.world_size <= 0:
            raise ValueError(
                f"world_size must be positive. Received world_size={self.world_size}"
            )
        # A rank outside [0, world_size) cannot identify any shard.
        if not 0 <= self.rank < self.world_size:
            raise ValueError(
                "rank must satisfy 0 <= rank < world_size. "
                f"Received rank={self.rank}, world_size={self.world_size}"
            )
@dataclass(frozen=True)