# Source code for gigl.distributed.dist_sampling_producer

# Significant portions of this file are taken from GraphLearn-for-PyTorch
# (graphlearn_torch/python/distributed/dist_sampling_producer.py).
# This version uses GiGL's sampler hierarchy (BaseGiGLSampler subclasses:
# DistNeighborSampler for k-hop, DistPPRNeighborSampler for PPR) instead of
# GLT's DistNeighborSampler directly.

import datetime
import queue
from threading import Barrier
from typing import Optional, Union, cast

import torch
import torch.multiprocessing as mp
from graphlearn_torch.channel import ChannelBase
from graphlearn_torch.distributed import (
    DistDataset,
    DistMpSamplingProducer,
    MpDistSamplingWorkerOptions,
    init_rpc,
    init_worker_group,
    shutdown_rpc,
)
from graphlearn_torch.distributed.dist_sampling_producer import (
    MP_STATUS_CHECK_INTERVAL,
    MpCommand,
)
from graphlearn_torch.sampler import (
    EdgeSamplerInput,
    NodeSamplerInput,
    SamplingConfig,
    SamplingType,
)
from graphlearn_torch.typing import EdgeType
from graphlearn_torch.utils import seed_everything
from torch._C import _set_worker_signal_handlers
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset

from gigl.common.logger import Logger
from gigl.distributed.sampler_options import SamplerOptions
from gigl.distributed.utils.dist_sampler import create_dist_sampler

[docs] logger = Logger()
def _sampling_worker_loop(
    rank: int,
    data: DistDataset,
    sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
    unshuffled_index: Optional[torch.Tensor],
    sampling_config: SamplingConfig,
    worker_options: MpDistSamplingWorkerOptions,
    channel: ChannelBase,
    task_queue: mp.Queue,
    sampling_completed_worker_count,  # mp.Value
    mp_barrier: Barrier,
    sampler_options: SamplerOptions,
    degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
):
    """Entry point for one sampling subprocess.

    Sets up the worker's distributed context (worker group, optional gloo
    process group, RPC), builds a GiGL sampler via ``create_dist_sampler``,
    then services commands from ``task_queue`` until a STOP command arrives.
    Sampled results are pushed through ``channel`` by the sampler itself.

    Args:
        rank: Index of this worker within ``worker_options.worker_ranks`` /
            ``worker_options.worker_devices``.
        data: Partitioned dataset this worker samples from.
        sampler_input: Full set of seed inputs; indexed per batch below.
        unshuffled_index: Precomputed seed index when shuffling is disabled;
            ``None`` when the parent will send shuffled indexes per epoch.
        sampling_config: Batch size, sampling type, seed, drop_last, etc.
        worker_options: Multiprocess sampling options (addresses, ranks,
            devices, RPC settings).
        channel: Output channel the sampler writes sampled messages to.
        task_queue: Receives ``(MpCommand, args)`` tuples from the parent.
        sampling_completed_worker_count: Shared ``mp.Value`` counter the
            parent polls to detect epoch completion.
        mp_barrier: Synchronizes "worker fully initialized" with the parent.
        sampler_options: GiGL sampler options forwarded to the factory.
        degree_tensors: Optional node-degree tensor(s) forwarded to the
            factory (per edge type when heterogeneous).
    """
    dist_sampler = None
    try:
        # Join the logical worker group before any RPC traffic.
        init_worker_group(
            world_size=worker_options.worker_world_size,
            rank=worker_options.worker_ranks[rank],
            group_name="_sampling_worker_subprocess",
        )
        if worker_options.use_all2all:
            # all2all-style exchange needs a real torch.distributed process
            # group; rendezvous over TCP at the sampling master address.
            torch.distributed.init_process_group(
                backend="gloo",
                timeout=datetime.timedelta(seconds=worker_options.rpc_timeout),
                rank=worker_options.worker_ranks[rank],
                world_size=worker_options.worker_world_size,
                init_method="tcp://{}:{}".format(
                    worker_options.master_addr, worker_options.master_port
                ),
            )
        if worker_options.num_rpc_threads is None:
            # Default thread count scales with partition count, capped at 16.
            num_rpc_threads = min(data.num_partitions, 16)
        else:
            num_rpc_threads = worker_options.num_rpc_threads
        current_device = worker_options.worker_devices[rank]
        # Restore default signal handling so the parent can terminate us.
        _set_worker_signal_handlers()
        # One compute thread on top of the RPC threads.
        torch.set_num_threads(num_rpc_threads + 1)
        init_rpc(
            master_addr=worker_options.master_addr,
            master_port=worker_options.master_port,
            num_rpc_threads=num_rpc_threads,
            rpc_timeout=worker_options.rpc_timeout,
        )
        if sampling_config.seed is not None:
            seed_everything(sampling_config.seed)
        # Build the GiGL sampler (k-hop or PPR, per sampler_options).
        dist_sampler = create_dist_sampler(
            data=data,
            sampling_config=sampling_config,
            worker_options=worker_options,
            channel=channel,
            sampler_options=sampler_options,
            degree_tensors=degree_tensors,
            current_device=current_device,
        )
        dist_sampler.start_loop()

        unshuffled_index_loader: Optional[DataLoader]
        loader: DataLoader
        if unshuffled_index is not None:
            # Shuffle is off: build the seed-index loader once and reuse it
            # for every SAMPLE_ALL command.
            unshuffled_index_loader = DataLoader(
                cast(Dataset, unshuffled_index),
                batch_size=sampling_config.batch_size,
                shuffle=False,
                drop_last=sampling_config.drop_last,
            )
        else:
            unshuffled_index_loader = None
        # Signal the parent that this worker is fully initialized.
        mp_barrier.wait()

        keep_running = True
        while keep_running:
            try:
                command, args = task_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                # Periodic wakeup; loop again so signal handlers can run.
                continue
            if command == MpCommand.SAMPLE_ALL:
                seeds_index = args
                if seeds_index is None:
                    # Parent sent no indexes: use the prebuilt unshuffled loader.
                    assert unshuffled_index_loader is not None
                    loader = unshuffled_index_loader
                else:
                    # Parent shuffled and sharded indexes for this epoch.
                    loader = DataLoader(
                        seeds_index,
                        batch_size=sampling_config.batch_size,
                        shuffle=False,
                        drop_last=sampling_config.drop_last,
                    )
                # Dispatch every batch of seeds to the matching sampler API.
                if sampling_config.sampling_type == SamplingType.NODE:
                    for index in loader:
                        dist_sampler.sample_from_nodes(sampler_input[index])
                elif sampling_config.sampling_type == SamplingType.LINK:
                    for index in loader:
                        dist_sampler.sample_from_edges(sampler_input[index])
                elif sampling_config.sampling_type == SamplingType.SUBGRAPH:
                    for index in loader:
                        dist_sampler.subgraph(sampler_input[index])
                # Block until all outstanding sampling tasks have finished.
                dist_sampler.wait_all()
                with sampling_completed_worker_count.get_lock():
                    sampling_completed_worker_count.value += (
                        1  # non-atomic, lock is necessary
                    )
            elif command == MpCommand.STOP:
                keep_running = False
            else:
                raise RuntimeError("Unknown command type")
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    # Tear down in reverse order of setup; sampler may be None if init failed.
    if dist_sampler is not None:
        dist_sampler.shutdown_loop()
    shutdown_rpc(graceful=False)
class DistSamplingProducer(DistMpSamplingProducer):
    """Multiprocess sampling producer using GiGL's sampler hierarchy.

    Extends GLT's ``DistMpSamplingProducer`` with GiGL-specific sampler
    configuration (``SamplerOptions``) and optional per-edge-type degree
    tensors, both of which are forwarded into each sampling subprocess.
    """

    def __init__(
        self,
        data: DistDataset,
        sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
        sampling_config: SamplingConfig,
        worker_options: MpDistSamplingWorkerOptions,
        channel: ChannelBase,
        sampler_options: SamplerOptions,
        degree_tensors: Optional[
            Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
        ] = None,
    ):
        """Store GiGL-specific state on top of the GLT base producer.

        Args:
            data: Partitioned dataset to sample from.
            sampler_input: Seed nodes or edges for sampling.
            sampling_config: Batching and sampling configuration.
            worker_options: Multiprocess sampling worker options.
            channel: Channel the sampling subprocesses write results to.
            sampler_options: GiGL sampler selection/configuration, passed to
                every sampling subprocess.
            degree_tensors: Optional node-degree tensor(s) — a dict keyed by
                edge type for heterogeneous graphs — passed to every
                sampling subprocess. Defaults to ``None``.
        """
        super().__init__(data, sampler_input, sampling_config, worker_options, channel)
        # Kept on the producer so init() can hand them to each subprocess.
        self._sampler_options = sampler_options
        self._degree_tensors = degree_tensors
    def init(self):
        r"""Create the subprocess pool. Init samplers and rpc server.

        Spawns ``self.num_workers`` sampling subprocesses (spawn context),
        each running :func:`_sampling_worker_loop`, and blocks on a barrier
        until every worker has finished initializing.
        """
        if self.sampling_config.seed is not None:
            seed_everything(self.sampling_config.seed)
        if not self.sampling_config.shuffle:
            # No shuffling: precompute each worker's seed indexes once so
            # workers can reuse them every epoch.
            unshuffled_indexes = self._get_seeds_indexes()
        else:
            # Shuffling: indexes are generated and sent per epoch instead.
            unshuffled_indexes = [None] * self.num_workers
        mp_context = mp.get_context("spawn")
        # +1 so this parent process participates in the barrier too.
        barrier = mp_context.Barrier(self.num_workers + 1)
        for rank in range(self.num_workers):
            # Bounded queue; capacity scales with workers * concurrency.
            task_queue = mp_context.Queue(
                self.num_workers * self.worker_options.worker_concurrency
            )
            self._task_queues.append(task_queue)
            worker = mp_context.Process(
                target=_sampling_worker_loop,
                args=(
                    rank,
                    self.data,
                    self.sampler_input,
                    unshuffled_indexes[rank],
                    self.sampling_config,
                    self.worker_options,
                    self.output_channel,
                    task_queue,
                    self.sampling_completed_worker_count,
                    barrier,
                    self._sampler_options,
                    self._degree_tensors,
                ),
            )
            # Daemonize so workers die with the parent process.
            worker.daemon = True
            worker.start()
            self._workers.append(worker)
        # Wait until all workers have built their samplers and RPC servers.
        barrier.wait()