Source code for gigl.distributed.graph_store.messages
"""RPC request messages for graph-store operations."""
from dataclasses import dataclass
from typing import Literal, Optional, Union
from graphlearn_torch.distributed import RemoteDistSamplingWorkerOptions
from graphlearn_torch.sampler import (
EdgeSamplerInput,
NodeSamplerInput,
RemoteSamplerInput,
SamplingConfig,
)
from gigl.distributed.graph_store.sharding import ServerSlice
from gigl.distributed.sampler import ABLPNodeSamplerInput
from gigl.distributed.sampler_options import SamplerOptions
from gigl.src.common.types.graph_data import EdgeType, NodeType
@dataclass(frozen=True)
[docs]
class InitSamplingBackendRequest:
"""Request to initialize a shared sampling backend on a storage server.
Args:
backend_key: A unique key identifying the backend (e.g. ``"dist_neighbor_loader_0"``).
worker_options: Options for launching remote sampling workers.
sampler_options: Controls which sampler class is instantiated.
sampling_config: Configuration for sampling behavior.
"""
[docs]
worker_options: RemoteDistSamplingWorkerOptions
[docs]
sampler_options: SamplerOptions
[docs]
sampling_config: SamplingConfig
@dataclass(frozen=True)
[docs]
class RegisterBackendRequest:
"""Request to register one compute-rank input channel on a backend.
Args:
backend_id: The ID of the backend to register on.
worker_key: A unique key identifying this compute-rank channel.
sampler_input: The input data for sampling.
sampling_config: Configuration for sampling behavior.
buffer_capacity: Number of shared-memory buffer slots.
buffer_size: Size of each buffer slot (int bytes or string like ``"1MB"``).
"""
[docs]
sampling_config: SamplingConfig
[docs]
buffer_size: Union[int, str]
@dataclass(frozen=True)
[docs]
class FetchNodesRequest:
"""Request for fetching node IDs from a storage server.
Args:
split: The split of the dataset to get node ids from.
node_type: The type of nodes to get node ids for.
server_slice: An optional :class:`~gigl.distributed.graph_store.sharding.ServerSlice`
describing the fraction of this server's data to return.
When ``None``, all of the server's data is returned.
Examples:
Fetch all nodes without sharding:
>>> FetchNodesRequest()
Fetch nodes of a specific type:
>>> FetchNodesRequest(node_type="user")
Fetch the first half of a server's training nodes:
>>> FetchNodesRequest(split="train", server_slice=ServerSlice(0, 0, 1, 2))
"""
[docs]
split: Optional[Union[Literal["train", "val", "test"], str]] = None
[docs]
node_type: Optional[NodeType] = None
[docs]
server_slice: Optional[ServerSlice] = None
@dataclass(frozen=True)