Source code for gigl.distributed.graph_store.messages

"""RPC request messages for graph-store fetch operations."""

from dataclasses import dataclass
from typing import Literal, Optional, Union

from gigl.src.common.types.graph_data import EdgeType, NodeType


@dataclass(frozen=True)
class FetchNodesRequest:
    """Request for fetching node IDs from a storage server.

    Args:
        rank: The rank of the process requesting node ids. Must be provided
            together with ``world_size``.
        world_size: The total number of processes in the distributed setup.
            Must be provided together with ``rank``.
        split: The split of the dataset to get node ids from.
        node_type: The type of nodes to get node ids for.

    Examples:
        Fetch all nodes without sharding:

        >>> FetchNodesRequest()

        Fetch training nodes for rank 0 of 4:

        >>> FetchNodesRequest(rank=0, world_size=4, split="train")

        Fetch nodes of a specific type:

        >>> FetchNodesRequest(node_type="user")
    """

    rank: Optional[int] = None
    world_size: Optional[int] = None
    split: Optional[Union[Literal["train", "val", "test"], str]] = None
    node_type: Optional[NodeType] = None

    def validate(self) -> None:
        """Validate that the request has a consistent rank/world_size pair.

        Leaving both ``rank`` and ``world_size`` unset is valid and requests
        un-sharded data. When both are set they must describe a real shard:
        ``world_size`` must be positive and ``rank`` must satisfy
        ``0 <= rank < world_size``.

        Raises:
            ValueError: If only one of ``rank`` or ``world_size`` is provided,
                if ``world_size`` is not positive, or if ``rank`` is outside
                ``[0, world_size)``.
        """
        if (self.rank is None) ^ (self.world_size is None):
            raise ValueError(
                "rank and world_size must be provided together. "
                f"Received rank={self.rank}, world_size={self.world_size}"
            )
        if self.rank is None or self.world_size is None:
            # Both unset: no sharding requested, nothing further to check.
            return
        if self.world_size <= 0:
            raise ValueError(
                "world_size must be a positive integer. "
                f"Received world_size={self.world_size}"
            )
        # A negative or too-large rank does not identify any shard; reject it
        # here rather than letting it be used as an index further downstream.
        if not 0 <= self.rank < self.world_size:
            raise ValueError(
                "rank must satisfy 0 <= rank < world_size. "
                f"Received rank={self.rank}, world_size={self.world_size}"
            )
@dataclass(frozen=True)
class FetchABLPInputRequest:
    """Request for fetching ABLP input from a storage server.

    Args:
        split: The split of the dataset to get ABLP input from.
        node_type: The type of anchor nodes to retrieve.
        supervision_edge_type: The edge type used for supervision.
        rank: The rank of the process requesting ABLP input. Must be provided
            together with ``world_size``.
        world_size: The total number of processes in the distributed setup.
            Must be provided together with ``rank``.

    Examples:
        Fetch training ABLP input without sharding:

        >>> FetchABLPInputRequest(
        ...     split="train",
        ...     node_type="user",
        ...     supervision_edge_type=("user", "to", "item"),
        ... )

        Fetch training ABLP input for rank 0 of 4:

        >>> FetchABLPInputRequest(
        ...     split="train",
        ...     node_type="user",
        ...     supervision_edge_type=("user", "to", "item"),
        ...     rank=0,
        ...     world_size=4,
        ... )
    """

    split: Union[Literal["train", "val", "test"], str]
    node_type: NodeType
    supervision_edge_type: EdgeType
    rank: Optional[int] = None
    world_size: Optional[int] = None

    def validate(self) -> None:
        """Validate that the request has a consistent rank/world_size pair.

        Leaving both ``rank`` and ``world_size`` unset is valid and requests
        un-sharded data. When both are set they must describe a real shard:
        ``world_size`` must be positive and ``rank`` must satisfy
        ``0 <= rank < world_size``.

        Raises:
            ValueError: If only one of ``rank`` or ``world_size`` is provided,
                if ``world_size`` is not positive, or if ``rank`` is outside
                ``[0, world_size)``.
        """
        if (self.rank is None) ^ (self.world_size is None):
            raise ValueError(
                "rank and world_size must be provided together. "
                f"Received rank={self.rank}, world_size={self.world_size}"
            )
        if self.rank is None or self.world_size is None:
            # Both unset: no sharding requested, nothing further to check.
            return
        if self.world_size <= 0:
            raise ValueError(
                "world_size must be a positive integer. "
                f"Received world_size={self.world_size}"
            )
        # A negative or too-large rank does not identify any shard; reject it
        # here rather than letting it be used as an index further downstream.
        if not 0 <= self.rank < self.world_size:
            raise ValueError(
                "rank must satisfy 0 <= rank < world_size. "
                f"Received rank={self.rank}, world_size={self.world_size}"
            )