Source code for gigl.distributed.graph_store.remote_channel

"""GiGL-owned remote receiving channel for graph-store sampling.

This mirrors GLT's ``RemoteReceivingChannel`` [1] behavior, but routes fetch RPCs
through GiGL server methods so channel-based sampling works with shared
producers.

[1] https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/channel/remote_channel.py
"""

from __future__ import annotations

import queue
import time
from collections import abc
from typing import Callable, Optional, Union

import torch
from graphlearn_torch.channel import ChannelBase, SampleMessage

from gigl.common.logger import Logger
from gigl.distributed.graph_store.compute import async_request_server
from gigl.distributed.graph_store.dist_server import DistServer

logger = Logger()
class RemoteReceivingChannel(ChannelBase):
    def __init__(
        self,
        server_rank: Union[int, list[int]],
        channel_id: Union[int, list[int]],
        prefetch_size: int = 2,
        active_mask: Optional[list[bool]] = None,
        pin_memory: bool = False,
    ) -> None:
        """Pull-based receiving channel that fetches sampled messages from servers.

        Args:
            server_rank: Target storage server rank(s).
            channel_id: Sampling channel id(s), one per server rank.
            prefetch_size: Number of in-flight fetch requests per server.
            active_mask: Optional per-server mask indicating which channels can
                produce at least one batch this epoch. Inactive servers are
                treated as already finished and are never polled.
            pin_memory: If True, copy received tensors to CUDA-pinned host memory
                before returning from ``recv()``. Enables faster GPU transfers via
                DMA in the downstream collate function.

        Differences from GLT's ``RemoteReceivingChannel.__init__``:

        - ``active_mask`` parameter: marks servers with no data this epoch so they
          are never polled; ``server_end_of_epoch`` is initialized from this mask
          instead of all-``False``.
        - ``pin_memory`` parameter: when ``True``, ``recv()`` copies tensors to
          CUDA-pinned host memory for faster DMA-based GPU transfers.
        - Explicit ``ValueError`` validation replaces bare ``assert`` for length
          checks on ``server_rank`` / ``channel_id`` / ``active_mask``.
        - Recv-count logging state (``_recv_count``, ``_log_every_n``) supports
          periodic timing telemetry added to ``recv()``.
        - Typed queue:
          ``Queue[tuple[Optional[Union[SampleMessage, Exception]], bool, int]]``.
        """
        self._server_rank_list = (
            list(server_rank)
            if isinstance(server_rank, abc.Sequence)
            and not isinstance(server_rank, int)
            else [int(server_rank)]
        )
        self._channel_id_list = (
            list(channel_id)
            if isinstance(channel_id, abc.Sequence)
            and not isinstance(channel_id, int)
            else [int(channel_id)]
        )
        self._prefetch_size = prefetch_size
        if len(self._server_rank_list) != len(self._channel_id_list):
            raise ValueError(
                "server_rank and channel_id must have the same length, got "
                f"{len(self._server_rank_list)} and {len(self._channel_id_list)}"
            )
        if active_mask is None:
            self._active_mask = [True] * len(self._server_rank_list)
        else:
            if len(active_mask) != len(self._server_rank_list):
                raise ValueError(
                    "active_mask must have the same length as server_rank/channel_id, got "
                    f"{len(active_mask)} and {len(self._server_rank_list)}"
                )
            self._active_mask = list(active_mask)
        self._num_request_list = [0] * len(self._server_rank_list)
        self._num_received_list = [0] * len(self._server_rank_list)
        self._server_end_of_epoch = [not is_active for is_active in self._active_mask]
        self._global_end_of_epoch = all(self._server_end_of_epoch)
        self._queue: queue.Queue[
            tuple[Optional[Union[SampleMessage, Exception]], bool, int]
        ] = queue.Queue(maxsize=self._prefetch_size * len(self._server_rank_list))
        self._recv_count: int = 0
        self._log_every_n: int = 50
        # Calling `pin_memory()` when no CUDA device is available raises
        # "No CUDA GPUs are available", so fail fast here with a clearer error.
        if not torch.cuda.is_available() and pin_memory:
            raise ValueError("pin_memory is only supported when CUDA is available")
        self._pin_memory = pin_memory
    def reset(self) -> None:
        """Reset all state to start a new epoch."""
        while not self._queue.empty():
            _ = self._queue.get()
        self._server_end_of_epoch = [not is_active for is_active in self._active_mask]
        self._num_request_list = [0] * len(self._server_rank_list)
        self._num_received_list = [0] * len(self._server_rank_list)
        self._global_end_of_epoch = all(self._server_end_of_epoch)
        self._recv_count = 0
    def send(self, msg: SampleMessage, **kwargs: object) -> None:
        raise RuntimeError(
            f"'{self.__class__.__name__}': cannot send "
            "message with a receiving channel."
        )
    def recv(self, **kwargs: object) -> SampleMessage:
        """Receive the next sampled message from the remote servers.

        This method blocks until a message is available or the epoch ends.

        Args:
            kwargs: Additional keyword arguments - unused.

        Returns:
            The sampled message.

        Raises:
            StopIteration: If the epoch ends and no messages are available.
            Exception: If a fetch RPC's future failed; the original exception is
                re-raised.
        """
        request_some_elapsed = 0.0
        num_dispatched = 0
        if self._global_end_of_epoch:
            if self._all_received():
                raise StopIteration
        else:
            request_some_start = time.monotonic()
            num_dispatched = self._request_some()
            request_some_elapsed = time.monotonic() - request_some_start
        queue_depth = self._queue.qsize()
        queue_get_start = time.monotonic()
        msg, end_of_epoch, local_server_idx = self._queue.get()
        queue_get_elapsed = time.monotonic() - queue_get_start
        self._num_received_list[local_server_idx] += 1
        # Server guarantees that when end_of_epoch is true, msg is None.
        while end_of_epoch:
            self._server_end_of_epoch[local_server_idx] = True
            if sum(self._server_end_of_epoch) == len(self._server_rank_list):
                self._global_end_of_epoch = True
                if self._all_received():
                    raise StopIteration
            msg, end_of_epoch, local_server_idx = self._queue.get()
            self._num_received_list[local_server_idx] += 1
        self._recv_count += 1
        if self._recv_count % self._log_every_n == 0:
            logger.info(
                "remote_channel_recv "
                f"recv_count={self._recv_count} "
                f"request_some_time={request_some_elapsed:.4f}s "
                f"num_rpcs_dispatched={num_dispatched} "
                f"queue_depth_before_get={queue_depth} "
                f"queue_get_time={queue_get_elapsed:.4f}s "
            )
        if isinstance(msg, Exception):
            raise msg
        return msg
    @staticmethod
    def _pin_sample_message(msg: Optional[SampleMessage]) -> Optional[SampleMessage]:
        """Copy all tensors in the message to CUDA-pinned host memory.

        This enables faster DMA transfers when subsequently calling ``.to(device)``
        in the collate function. See
        https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html#
        for more details on pin_memory.
        """
        if msg is None:
            return None
        pinned: SampleMessage = {}
        for k, v in msg.items():
            pinned[k] = v.pin_memory()
        return pinned

    def _all_received(self) -> bool:
        return sum(self._num_received_list) == sum(self._num_request_list)

    def _request_some(self) -> int:
        """Dispatch prefetch RPCs. Returns the number of new RPCs dispatched."""
        num_dispatched = 0

        def on_done(
            future: torch.futures.Future[tuple[Optional[SampleMessage], bool]],
            local_server_idx: int,
        ) -> None:
            try:
                msg, end_of_epoch = future.wait()
                if self._pin_memory:
                    msg = self._pin_sample_message(msg)
                self._queue.put((msg, end_of_epoch, local_server_idx))
            except Exception as exc:
                logger.error("broken future of receiving remote messages: %s", exc)
                self._queue.put((exc, False, local_server_idx))

        def create_callback(
            local_server_idx: int,
        ) -> Callable[
            [torch.futures.Future[tuple[Optional[SampleMessage], bool]]], None
        ]:
            def callback(
                future: torch.futures.Future[tuple[Optional[SampleMessage], bool]],
            ) -> None:
                on_done(future, local_server_idx)

            return callback

        for local_server_idx, server_rank in enumerate(self._server_rank_list):
            if not self._active_mask[local_server_idx]:
                continue
            if self._server_end_of_epoch[local_server_idx]:
                continue
            # Keep `prefetch_size` requests in flight per active server:
            # dispatch only enough new RPCs to top the pipeline back up.
            missing = (
                self._num_received_list[local_server_idx]
                + self._prefetch_size
                - self._num_request_list[local_server_idx]
            )
            for _ in range(missing):
                future = async_request_server(
                    server_rank,
                    DistServer.fetch_one_sampled_message,
                    self._channel_id_list[local_server_idx],
                )
                future.add_done_callback(create_callback(local_server_idx))
                self._num_request_list[local_server_idx] += 1
                num_dispatched += 1
        return num_dispatched
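
# Illustrative usage (a minimal sketch, not part of the module's API surface):
# the snippet below assumes two storage servers (ranks 0 and 1) already serve a
# shared sampling producer on channel id 0, which this file does not set up.
# The rank and channel values are placeholders; the sketch only shows the
# intended construction and one-epoch drain pattern.
#
#     channel = RemoteReceivingChannel(
#         server_rank=[0, 1],
#         channel_id=[0, 0],
#         prefetch_size=2,
#         pin_memory=torch.cuda.is_available(),
#     )
#     while True:
#         try:
#             msg = channel.recv()  # Blocks until a SampleMessage arrives.
#         except StopIteration:
#             break  # All servers hit end_of_epoch and every message was drained.
#         ...  # Collate `msg` and move tensors to the target device.
#     channel.reset()  # Clear counters and flags before the next epoch.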