# Significant portions of this file are taken from GraphLearn-for-PyTorch
# (graphlearn_torch/python/distributed/dist_sampling_producer.py).
# This version uses GiGL's DistNeighborSampler (which supports both standard
# neighbor sampling and ABLP) instead of GLT's DistNeighborSampler.
import datetime
import queue
from threading import Barrier
from typing import Optional, Union, cast
import torch
import torch.multiprocessing as mp
from graphlearn_torch.channel import ChannelBase
from graphlearn_torch.distributed import (
DistDataset,
DistMpSamplingProducer,
MpDistSamplingWorkerOptions,
init_rpc,
init_worker_group,
shutdown_rpc,
)
from graphlearn_torch.distributed.dist_sampling_producer import (
MP_STATUS_CHECK_INTERVAL,
MpCommand,
)
from graphlearn_torch.sampler import (
EdgeSamplerInput,
NodeSamplerInput,
SamplingConfig,
SamplingType,
)
from graphlearn_torch.typing import EdgeType
from graphlearn_torch.utils import seed_everything
from torch._C import _set_worker_signal_handlers
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from gigl.common.logger import Logger
from gigl.distributed.dist_dataset import DistDataset as GiglDistDataset
from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler
from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler
from gigl.distributed.sampler_options import (
KHopNeighborSamplerOptions,
PPRSamplerOptions,
SamplerOptions,
)
def _build_dist_sampler(
    data: DistDataset,
    sampling_config: SamplingConfig,
    worker_options: MpDistSamplingWorkerOptions,
    channel: ChannelBase,
    current_device,
    sampler_options: SamplerOptions,
    degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
) -> Union[DistNeighborSampler, DistPPRNeighborSampler]:
    """Instantiate the distributed sampler selected by ``sampler_options``.

    Args:
        data: Dataset handed to the sampler as its first positional argument.
        sampling_config: Supplies the neighbor/edge/feature flags and the seed.
        worker_options: Supplies ``use_all2all`` and ``worker_concurrency``.
        channel: Channel handed to the sampler.
        current_device: Device assigned to the calling worker.
        sampler_options: ``KHopNeighborSamplerOptions`` selects
            ``DistNeighborSampler``; ``PPRSamplerOptions`` selects
            ``DistPPRNeighborSampler``.
        degree_tensors: Degree tensor(s); must not be ``None`` for PPR sampling.

    Returns:
        The constructed (not yet started) sampler.

    Raises:
        NotImplementedError: If ``sampler_options`` is of an unsupported type.
    """
    # Shared args for all sampler types (positional args to
    # DistNeighborSampler.__init__); the order must not change.
    shared_sampler_args = (
        data,
        sampling_config.num_neighbors,
        sampling_config.with_edge,
        sampling_config.with_neg,
        sampling_config.with_weight,
        sampling_config.edge_dir,
        sampling_config.collect_features,
        channel,
        worker_options.use_all2all,
        worker_options.worker_concurrency,
        current_device,
    )
    if isinstance(sampler_options, KHopNeighborSamplerOptions):
        return DistNeighborSampler(
            *shared_sampler_args,
            seed=sampling_config.seed,
        )
    if isinstance(sampler_options, PPRSamplerOptions):
        assert degree_tensors is not None
        return DistPPRNeighborSampler(
            *shared_sampler_args,
            seed=sampling_config.seed,
            alpha=sampler_options.alpha,
            eps=sampler_options.eps,
            max_ppr_nodes=sampler_options.max_ppr_nodes,
            num_neighbors_per_hop=sampler_options.num_neighbors_per_hop,
            total_degree_dtype=sampler_options.total_degree_dtype,
            degree_tensors=degree_tensors,
        )
    raise NotImplementedError(
        f"Unsupported sampler options type: {type(sampler_options)}"
    )


def _sampling_worker_loop(
    rank: int,
    data: DistDataset,
    sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
    unshuffled_index: Optional[torch.Tensor],
    sampling_config: SamplingConfig,
    worker_options: MpDistSamplingWorkerOptions,
    channel: ChannelBase,
    task_queue: mp.Queue,
    sampling_completed_worker_count,  # mp.Value
    mp_barrier: Barrier,
    sampler_options: SamplerOptions,
    degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
):
    """Entry point of one sampling worker subprocess.

    Sets up the worker group, (optionally) the gloo process group and RPC,
    builds and starts the sampler, then serves commands from ``task_queue``
    until ``MpCommand.STOP`` is received.

    Args:
        rank: Local index of this worker; used to index
            ``worker_options.worker_ranks`` and ``worker_options.worker_devices``.
        data: Dataset handed to the sampler; ``data.num_partitions`` bounds the
            default number of RPC threads.
        sampler_input: Seed container, indexed with each batch of seed indices.
        unshuffled_index: Optional fixed seed indices. When given, they are
            wrapped in a ``DataLoader`` once and reused whenever a
            ``SAMPLE_ALL`` command carries ``None`` as its argument.
        sampling_config: Batch size, sampling type, seed and sampler flags.
        worker_options: Worker group / RPC / device configuration.
        channel: Channel handed to the sampler.
        task_queue: Queue of ``(command, args)`` tuples from the producer.
        sampling_completed_worker_count: Shared counter (``mp.Value``)
            incremented under its lock after each completed ``SAMPLE_ALL``.
        mp_barrier: Barrier waited on once after startup completes.
        sampler_options: Selects which sampler implementation to build.
        degree_tensors: Degree tensor(s) required for PPR sampling.

    Raises:
        NotImplementedError: If ``sampler_options`` is of an unsupported type.
        RuntimeError: If an unknown command is received on ``task_queue``.
    """
    dist_sampler = None
    # Tracks whether this worker has passed the startup barrier, so a startup
    # failure can break the barrier instead of leaving its peers blocked on it.
    startup_complete = False
    try:
        init_worker_group(
            world_size=worker_options.worker_world_size,
            rank=worker_options.worker_ranks[rank],
            group_name="_sampling_worker_subprocess",
        )
        if worker_options.use_all2all:
            torch.distributed.init_process_group(
                backend="gloo",
                timeout=datetime.timedelta(seconds=worker_options.rpc_timeout),
                rank=worker_options.worker_ranks[rank],
                world_size=worker_options.worker_world_size,
                init_method="tcp://{}:{}".format(
                    worker_options.master_addr, worker_options.master_port
                ),
            )

        if worker_options.num_rpc_threads is None:
            num_rpc_threads = min(data.num_partitions, 16)
        else:
            num_rpc_threads = worker_options.num_rpc_threads

        current_device = worker_options.worker_devices[rank]

        _set_worker_signal_handlers()
        torch.set_num_threads(num_rpc_threads + 1)

        init_rpc(
            master_addr=worker_options.master_addr,
            master_port=worker_options.master_port,
            num_rpc_threads=num_rpc_threads,
            rpc_timeout=worker_options.rpc_timeout,
        )

        if sampling_config.seed is not None:
            seed_everything(sampling_config.seed)

        dist_sampler = _build_dist_sampler(
            data,
            sampling_config,
            worker_options,
            channel,
            current_device,
            sampler_options,
            degree_tensors,
        )
        dist_sampler.start_loop()

        unshuffled_index_loader: Optional[DataLoader]
        loader: DataLoader
        if unshuffled_index is not None:
            unshuffled_index_loader = DataLoader(
                cast(Dataset, unshuffled_index),
                batch_size=sampling_config.batch_size,
                shuffle=False,
                drop_last=sampling_config.drop_last,
            )
        else:
            unshuffled_index_loader = None

        mp_barrier.wait()
        startup_complete = True

        keep_running = True
        while keep_running:
            try:
                command, args = task_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if command == MpCommand.SAMPLE_ALL:
                seeds_index = args
                if seeds_index is None:
                    assert unshuffled_index_loader is not None
                    loader = unshuffled_index_loader
                else:
                    loader = DataLoader(
                        seeds_index,
                        batch_size=sampling_config.batch_size,
                        shuffle=False,
                        drop_last=sampling_config.drop_last,
                    )

                if sampling_config.sampling_type == SamplingType.NODE:
                    for index in loader:
                        dist_sampler.sample_from_nodes(sampler_input[index])
                elif sampling_config.sampling_type == SamplingType.LINK:
                    for index in loader:
                        dist_sampler.sample_from_edges(sampler_input[index])
                elif sampling_config.sampling_type == SamplingType.SUBGRAPH:
                    for index in loader:
                        dist_sampler.subgraph(sampler_input[index])

                dist_sampler.wait_all()

                with sampling_completed_worker_count.get_lock():
                    sampling_completed_worker_count.value += (
                        1  # non-atomic, lock is necessary
                    )
            elif command == MpCommand.STOP:
                keep_running = False
            else:
                raise RuntimeError("Unknown command type")
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    except BaseException:
        # A failure before the barrier would otherwise leave every other
        # party waiting on it forever; breaking the barrier makes their
        # wait() raise BrokenBarrierError instead.
        if not startup_complete:
            mp_barrier.abort()
        raise
    finally:
        # Always tear down the sampler loop and RPC, including on errors.
        if dist_sampler is not None:
            dist_sampler.shutdown_loop()
        shutdown_rpc(graceful=False)
# Module-level logger used by DistSamplingProducer.init().
logger = Logger()


class DistSamplingProducer(DistMpSamplingProducer):
    """Multiprocess sampling producer whose workers run ``_sampling_worker_loop``.

    Extends ``DistMpSamplingProducer`` with a ``sampler_options`` argument that
    is forwarded to every worker subprocess to select the sampler
    implementation.
    """

    def __init__(
        self,
        data: DistDataset,
        sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
        sampling_config: SamplingConfig,
        worker_options: MpDistSamplingWorkerOptions,
        channel: ChannelBase,
        sampler_options: SamplerOptions,
    ):
        """Store ``sampler_options`` and delegate the rest to the base class.

        Args:
            data: Dataset passed to the base producer and to each worker.
            sampler_input: Seed container passed to each worker.
            sampling_config: Sampling configuration passed to each worker.
            worker_options: Worker configuration passed to each worker.
            channel: Output channel passed to the base producer.
            sampler_options: Sampler selection/options forwarded to workers.
        """
        super().__init__(data, sampler_input, sampling_config, worker_options, channel)
        self._sampler_options = sampler_options

    def init(self) -> None:
        r"""Create the subprocess pool. Init samplers and rpc server."""
        # Extract degree tensors before spawning workers. Worker subprocesses
        # only initialize RPC (not torch.distributed), so the lazy degree
        # computation on GiglDistDataset would fail there. Computing here —
        # where torch.distributed IS initialized — lets the tensor be shared
        # to workers via IPC.
        degree_tensors: Optional[
            Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
        ] = None
        if isinstance(self._sampler_options, PPRSamplerOptions):
            assert isinstance(self.data, GiglDistDataset)
            degree_tensors = self.data.degree_tensor
            if isinstance(degree_tensors, dict):
                logger.info(
                    f"Pre-computed degree tensors for PPR sampling across {len(degree_tensors)} edge types."
                )
            else:
                logger.info(
                    f"Pre-computed degree tensor for PPR sampling with {degree_tensors.size(0)} nodes."
                )

        if self.sampling_config.seed is not None:
            seed_everything(self.sampling_config.seed)

        # Without shuffling, each worker gets its seed indices once up front;
        # otherwise workers receive None here.
        if not self.sampling_config.shuffle:
            unshuffled_indexes = self._get_seeds_indexes()
        else:
            unshuffled_indexes = [None] * self.num_workers

        mp_context = mp.get_context("spawn")
        # One party per worker plus this (parent) process.
        barrier = mp_context.Barrier(self.num_workers + 1)
        for rank in range(self.num_workers):
            task_queue = mp_context.Queue(
                self.num_workers * self.worker_options.worker_concurrency
            )
            self._task_queues.append(task_queue)
            w = mp_context.Process(
                target=_sampling_worker_loop,
                args=(
                    rank,
                    self.data,
                    self.sampler_input,
                    unshuffled_indexes[rank],
                    self.sampling_config,
                    self.worker_options,
                    self.output_channel,
                    task_queue,
                    self.sampling_completed_worker_count,
                    barrier,
                    self._sampler_options,
                    degree_tensors,
                ),
            )
            w.daemon = True
            w.start()
            self._workers.append(w)
        # Block until every worker has reached its own barrier.wait().
        barrier.wait()