import sys
from collections import abc
from itertools import count
from typing import Callable, Optional, Tuple, Union
import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import (
MpDistSamplingWorkerOptions,
RemoteDistSamplingWorkerOptions,
)
from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer
from graphlearn_torch.sampler import NodeSamplerInput
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType
import gigl.distributed.utils
from gigl.common.logger import Logger
from gigl.distributed.base_dist_loader import BaseDistLoader
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.utils.neighborloader import (
DatasetSchema,
SamplingClusterSetup,
labeled_to_homogeneous,
set_missing_features,
shard_nodes_by_process,
strip_label_edges,
)
from gigl.src.common.types.graph_data import (
NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing
)
from gigl.types.graph import (
DEFAULT_HOMOGENEOUS_EDGE_TYPE,
DEFAULT_HOMOGENEOUS_NODE_TYPE,
)
# When using CPU based inference/training, we default cpu threads for neighborloading on top of the per process parallelism.
DEFAULT_NUM_CPU_THREADS = 2

# Module-level logger; used by DistNeighborLoader below.
logger = Logger()


# We don't see logs for graph store mode for whatever reason.
# TODO(#442): Revert this once the GCP issues are resolved.
def flush():
    """Flush stdout and stderr so any buffered log lines are emitted immediately."""
    sys.stdout.flush()
    sys.stderr.flush()
class DistNeighborLoader(BaseDistLoader):
    """Distributed neighbor loader supporting both colocated and Graph Store sampling setups."""

    # Counts instantiations of this class, per process.
    # This is needed so we can generate unique worker key for each instance, for graph store mode.
    # NOTE: This is per-class, not per-instance.
    _counter = count(0)
    def __init__(
        self,
        dataset: Union[DistDataset, RemoteDistDataset],
        num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
        input_nodes: Optional[
            Union[
                torch.Tensor,
                Tuple[NodeType, torch.Tensor],
                abc.Mapping[int, torch.Tensor],
                Tuple[NodeType, abc.Mapping[int, torch.Tensor]],
            ]
        ] = None,
        num_workers: int = 1,
        batch_size: int = 1,
        context: Optional[DistributedContext] = None,  # TODO: (svij) Deprecate this
        local_process_rank: Optional[int] = None,  # TODO: (svij) Deprecate this
        local_process_world_size: Optional[int] = None,  # TODO: (svij) Deprecate this
        pin_memory_device: Optional[torch.device] = None,
        worker_concurrency: int = 4,
        channel_size: str = "4GB",
        prefetch_size: Optional[int] = None,
        process_start_gap_seconds: float = 60.0,
        num_cpu_threads: Optional[int] = None,
        shuffle: bool = False,
        drop_last: bool = False,
    ):
        """
        Distributed Neighbor Loader.

        Takes in some input nodes and samples neighbors from the dataset.
        This loader should be used if you do not have any special sampling needs,
        e.g. you need to generate *training* examples for Anchor Based Link Prediction (ABLP) tasks.
        Though this loader is useful for generating random negative examples for ABLP training.

        Note: We try to adhere to pyg dataloader api as much as possible.
        See the following for reference:
        https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/loader/node_loader.html#NodeLoader
        https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader

        Args:
            dataset (DistDataset | RemoteDistDataset): The dataset to sample from.
                If this is a `RemoteDistDataset`, then we are assumed to be in "Graph Store" mode.
            num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]):
                The number of neighbors to sample for each node in each iteration.
                If an entry is set to `-1`, all neighbors will be included.
                In heterogeneous graphs, may also take in a dictionary denoting
                the amount of neighbors to sample for each individual edge type.
            context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process.
            local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node.
            local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node.
            input_nodes (Tensor | Tuple[NodeType, Tensor] | dict[int, Tensor] | Tuple[NodeType, dict[int, Tensor]]):
                The nodes to start sampling from.
                It is of type `torch.LongTensor` for homogeneous graphs.
                If set to `None` for homogeneous settings, all nodes will be considered.
                In heterogeneous graphs, this flag must be passed in as a tuple that holds
                the node type and node indices. (default: `None`)
                For Graph Store mode, this must be a tuple of (NodeType, dict[int, Tensor]) or dict[int, Tensor].
                Where each Tensor in the dict is the node ids to sample from, by server.
                e.g. {0: [10, 20], 1: [30, 40]} means sample from nodes 10 and 20 on server 0, and nodes 30 and 40 on server 1.
                If a Graph Store input (e.g. dict[int, Tensor]) is provided to colocated mode, or colocated input (e.g. Tensor) is provided to Graph Store mode,
                then an error will be raised.
            num_workers (int): How many workers to use (subprocesses to spawn) for
                distributed neighbor sampling of the current process. (default: ``1``).
            batch_size (int, optional): how many samples per batch to load
                (default: ``1``).
            pin_memory_device (str, optional): The target device that the sampled
                results should be copied to. If set to ``None``, the device is inferred based off of
                (got by ``gigl.distributed.utils.device.get_available_device``). Which uses the
                local_process_rank and torch.cuda.device_count() to assign the device. If cuda is not available,
                the cpu device will be used. (default: ``None``).
            worker_concurrency (int): The max sampling concurrency for each sampling
                worker. Load testing has shown that setting worker_concurrency to 4 yields the best performance
                for sampling. Although, you may wish to explore higher/lower settings when performance tuning.
                (default: `4`).
            channel_size (int or str): The shared-memory buffer size (bytes) allocated
                for the channel. Can be modified for performance tuning; a good starting point is: ``num_workers * 64MB``
                (default: "4GB").
            prefetch_size (Optional[int]): Max number of sampled messages to prefetch on the
                client side, per server. Only applies to Graph Store mode (remote workers).
                Lower values reduce server-side RPC thread contention when multiple loaders
                are active concurrently. (default: ``None``).
                Only applicable in Graph Store mode.
                If supplied and not in Graph Store mode, an error will be raised.
            process_start_gap_seconds (float): Delay between each process for initializing neighbor loader. At large scales,
                it is recommended to set this value to be between 60 and 120 seconds -- otherwise multiple processes may
                attempt to initialize dataloaders at overlapping times, which can cause CPU memory OOM.
            num_cpu_threads (Optional[int]): Number of cpu threads PyTorch should use for CPU training/inference
                neighbor loading; on top of the per process parallelism.
                Defaults to `2` if set to `None` when using cpu training/inference.
            shuffle (bool): Whether to shuffle the input nodes. (default: ``False``).
            drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``).
        """
        # Set self._shutdowned right away, that way if we throw here, and __del__ is called,
        # then we can properly clean up and don't get extraneous error messages.
        self._shutdowned = True

        # Resolve distributed context (also handles the deprecated `context` /
        # `local_process_rank` / `local_process_world_size` arguments).
        runtime = BaseDistLoader.resolve_runtime(
            context, local_process_rank, local_process_world_size
        )
        # Remove the deprecated names so the rest of __init__ can only use `runtime`.
        del context, local_process_rank, local_process_world_size

        # Determine mode: a RemoteDistDataset means "Graph Store" mode (remote
        # storage cluster); otherwise sampling is colocated with the data.
        if isinstance(dataset, RemoteDistDataset):
            self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE
        else:
            self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED
            # prefetch_size is a Graph Store-only knob; reject it in colocated mode.
            if prefetch_size is not None:
                raise ValueError(
                    f"prefetch_size must be None when using Colocated mode, received {prefetch_size}"
                )
        logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}")

        # Per-process instance id; used to build a unique worker key in Graph Store mode.
        self._instance_count = next(self._counter)

        device = (
            pin_memory_device
            if pin_memory_device
            else gigl.distributed.utils.get_available_device(
                local_process_rank=runtime.local_rank
            )
        )

        # Mode-specific setup: each branch produces the sampler input, the
        # worker options, and the dataset schema for the chosen mode.
        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
            assert isinstance(
                dataset, DistDataset
            ), "When using colocated mode, dataset must be a DistDataset."
            input_data, worker_options, dataset_schema = self._setup_for_colocated(
                input_nodes=input_nodes,
                dataset=dataset,
                local_rank=runtime.local_rank,
                local_world_size=runtime.local_world_size,
                device=device,
                master_ip_address=runtime.master_ip_address,
                node_rank=runtime.node_rank,
                node_world_size=runtime.node_world_size,
                num_workers=num_workers,
                worker_concurrency=worker_concurrency,
                channel_size=channel_size,
                num_cpu_threads=num_cpu_threads,
            )
        else:
            assert isinstance(
                dataset, RemoteDistDataset
            ), "When using Graph Store mode, dataset must be a RemoteDistDataset."
            if prefetch_size is None:
                logger.info(f"prefetch_size is not provided, using default of 4")
                prefetch_size = 4
            input_data, worker_options, dataset_schema = self._setup_for_graph_store(
                input_nodes=input_nodes,
                dataset=dataset,
                num_workers=num_workers,
                prefetch_size=prefetch_size,
                channel_size=channel_size,
            )

        # Cleanup temporary process group if needed (i.e. if resolve_runtime
        # initialized one just for this constructor).
        if (
            runtime.should_cleanup_distributed_context
            and torch.distributed.is_initialized()
        ):
            logger.info(
                f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
            )
            torch.distributed.destroy_process_group()

        # Create SamplingConfig (with patched fanout)
        sampling_config = BaseDistLoader.create_sampling_config(
            num_neighbors=num_neighbors,
            dataset_schema=dataset_schema,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )

        # Build the sampler: a pre-constructed producer for colocated mode,
        # or an RPC callable for graph store mode.
        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
            assert isinstance(dataset, DistDataset)
            assert isinstance(worker_options, MpDistSamplingWorkerOptions)
            channel = BaseDistLoader.create_colocated_channel(worker_options)
            sampler: Union[
                DistMpSamplingProducer, Callable[..., int]
            ] = DistMpSamplingProducer(
                dataset,
                input_data,
                sampling_config,
                worker_options,
                channel,
            )
        else:
            sampler = GiglDistServer.create_sampling_producer

        # Call base class — handles metadata storage and connection initialization
        # (including staggered init for colocated mode).
        super().__init__(
            dataset=dataset,
            sampler_input=input_data,
            dataset_schema=dataset_schema,
            worker_options=worker_options,
            sampling_config=sampling_config,
            device=device,
            runtime=runtime,
            sampler=sampler,
            process_start_gap_seconds=process_start_gap_seconds,
        )
def _setup_for_graph_store(
self,
input_nodes: Optional[
Union[
torch.Tensor,
tuple[NodeType, torch.Tensor],
abc.Mapping[int, torch.Tensor],
tuple[NodeType, abc.Mapping[int, torch.Tensor]],
]
],
dataset: RemoteDistDataset,
num_workers: int,
prefetch_size: int,
channel_size: str,
) -> tuple[list[NodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema]:
if input_nodes is None:
raise ValueError(
f"When using Graph Store mode, input nodes must be provided, received {input_nodes}"
)
elif isinstance(input_nodes, torch.Tensor):
raise ValueError(
f"When using Graph Store mode, input nodes must be of type (abc.Mapping[int, torch.Tensor] | (NodeType, abc.Mapping[int, torch.Tensor]), received {type(input_nodes)}"
)
elif isinstance(input_nodes, tuple) and isinstance(
input_nodes[1], torch.Tensor
):
raise ValueError(
f"When using Graph Store mode, input nodes must be of type (dict[int, torch.Tensor] | (NodeType, dict[int, torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
)
node_feature_info = dataset.fetch_node_feature_info()
edge_feature_info = dataset.fetch_edge_feature_info()
edge_types = dataset.fetch_edge_types()
node_rank = dataset.cluster_info.compute_node_rank
# Get sampling ports for compute-storage connections.
sampling_ports = dataset.fetch_free_ports_on_storage_cluster(
num_ports=dataset.cluster_info.num_compute_nodes
)
sampling_port = sampling_ports[node_rank]
worker_key = f"compute_rank_{node_rank}_worker_{self._instance_count}"
logger.info(f"Rank {torch.distributed.get_rank()} worker key: {worker_key}")
worker_options = RemoteDistSamplingWorkerOptions(
server_rank=list(range(dataset.cluster_info.num_storage_nodes)),
num_workers=num_workers,
worker_devices=[torch.device("cpu") for i in range(num_workers)],
master_addr=dataset.cluster_info.storage_cluster_master_ip,
buffer_size=channel_size,
master_port=sampling_port,
worker_key=worker_key,
prefetch_size=prefetch_size,
)
logger.info(
f"Rank {torch.distributed.get_rank()}! init for sampling rpc: {f'tcp://{dataset.cluster_info.storage_cluster_master_ip}:{sampling_port}'}"
)
# Setup input data for the dataloader.
# Determine nodes list and fallback input_type based on input_nodes structure
if isinstance(input_nodes, abc.Mapping):
nodes = input_nodes
fallback_input_type = None
require_edge_feature_info = False
elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping):
nodes = input_nodes[1]
fallback_input_type = input_nodes[0]
require_edge_feature_info = True
else:
raise ValueError(
f"When using Graph Store mode, input nodes must be of type (abc.Mapping[int, torch.Tensor] | (NodeType, abc.Mapping[int, torch.Tensor])), received {type(input_nodes)}"
)
# Determine input_type based on edge_feature_info
if isinstance(edge_types, list):
if DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_types:
input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE
else:
input_type = fallback_input_type
elif require_edge_feature_info:
raise ValueError(
"When using Graph Store mode, edge types must be provided for heterogeneous graphs."
)
else:
input_type = None
is_homogeneous_with_labeled_edge_type = (
input_type == DEFAULT_HOMOGENEOUS_NODE_TYPE
)
# Convert from dict to list which is what the GLT DistNeighborLoader expects.
servers = nodes.keys()
if max(servers) >= dataset.cluster_info.num_storage_nodes or min(servers) < 0:
raise ValueError(
f"When using Graph Store mode, the server ranks must be in range [0, num_servers ({dataset.cluster_info.num_storage_nodes})), received inputs for servers: {list(nodes.keys())}"
)
input_data: list[NodeSamplerInput] = []
for server_rank in range(dataset.cluster_info.num_storage_nodes):
if server_rank in nodes:
input_data.append(
NodeSamplerInput(node=nodes[server_rank], input_type=input_type)
)
else:
input_data.append(
NodeSamplerInput(
node=torch.empty(0, dtype=torch.long), input_type=input_type
)
)
return (
input_data,
worker_options,
DatasetSchema(
is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type,
edge_types=edge_types,
node_feature_info=node_feature_info,
edge_feature_info=edge_feature_info,
edge_dir=dataset.fetch_edge_dir(),
),
)
    def _setup_for_colocated(
        self,
        input_nodes: Optional[
            Union[
                torch.Tensor,
                Tuple[NodeType, torch.Tensor],
                abc.Mapping[int, torch.Tensor],
                Tuple[NodeType, abc.Mapping[int, torch.Tensor]],
            ]
        ],
        dataset: DistDataset,
        local_rank: int,
        local_world_size: int,
        device: torch.device,
        master_ip_address: str,
        node_rank: int,
        node_world_size: int,
        num_workers: int,
        worker_concurrency: int,
        channel_size: str,
        num_cpu_threads: Optional[int],
    ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]:
        """Build the sampler input, worker options, and schema for colocated mode.

        Validates that ``input_nodes`` has a colocated-mode shape (a Tensor, or a
        (node_type, Tensor) tuple), shards the input nodes across local processes,
        initializes this process's neighbor-loader worker, and assembles the
        multiprocessing sampling worker options.

        Returns:
            Tuple of (sampler input for this process, mp sampling worker options,
            dataset schema).

        Raises:
            ValueError: If ``input_nodes`` is missing for a heterogeneous dataset,
                or uses a Graph Store-style mapping shape.
        """
        if input_nodes is None:
            if dataset.node_ids is None:
                raise ValueError(
                    "Dataset must have node ids if input_nodes are not provided."
                )
            if isinstance(dataset.node_ids, abc.Mapping):
                raise ValueError(
                    f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}"
                )
            # Homogeneous dataset with no explicit input: sample from all nodes.
            input_nodes = dataset.node_ids
        # Mapping-shaped inputs are Graph Store-mode inputs; reject them here.
        if isinstance(input_nodes, abc.Mapping):
            raise ValueError(
                f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)}"
            )
        elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping):
            raise ValueError(
                f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
            )
        is_homogeneous_with_labeled_edge_type = False
        if isinstance(input_nodes, torch.Tensor):
            node_ids = input_nodes
            # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
            # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
            if isinstance(dataset.node_ids, abc.Mapping):
                if (
                    len(dataset.node_ids) == 1
                    and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
                ):
                    node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
                    is_homogeneous_with_labeled_edge_type = True
                else:
                    raise ValueError(
                        f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
                    )
            else:
                node_type = None
        else:
            # Tuple input: (node_type, node_ids) for a heterogeneous dataset.
            node_type, node_ids = input_nodes
            assert isinstance(
                dataset.node_ids, abc.Mapping
            ), "Dataset must be heterogeneous if provided input nodes are a tuple."
        # Each local process samples only its own shard of the input nodes.
        curr_process_nodes = shard_nodes_by_process(
            input_nodes=node_ids,
            local_process_rank=local_rank,
            local_process_world_size=local_world_size,
        )
        input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)
        # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize
        # the memory overhead and CPU contention.
        logger.info(
            f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} using device: {device}"
        )
        should_use_cpu_workers = device.type == "cpu"
        if should_use_cpu_workers and num_cpu_threads is None:
            logger.info(
                "Using CPU workers, but found num_cpu_threads to be None. "
                f"Will default setting num_cpu_threads to {DEFAULT_NUM_CPU_THREADS}."
            )
            num_cpu_threads = DEFAULT_NUM_CPU_THREADS
        # One port per local process; all processes must agree on the port list,
        # so the ports are allocated on (and fetched from) the master node.
        neighbor_loader_ports = gigl.distributed.utils.get_free_ports_from_master_node(
            num_ports=local_world_size
        )
        neighbor_loader_port_for_current_rank = neighbor_loader_ports[local_rank]
        gigl.distributed.utils.init_neighbor_loader_worker(
            master_ip_address=master_ip_address,
            local_process_rank=local_rank,
            local_process_world_size=local_world_size,
            rank=node_rank,
            world_size=node_world_size,
            master_worker_port=neighbor_loader_port_for_current_rank,
            device=device,
            should_use_cpu_workers=should_use_cpu_workers,
            # Lever to explore tuning for CPU based inference
            num_cpu_threads=num_cpu_threads,
        )
        logger.info(
            f"Finished initializing neighbor loader worker: {local_rank}/{local_world_size}"
        )
        # Sets up worker options for the dataloader
        dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node(
            num_ports=local_world_size
        )
        dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank]
        worker_options = MpDistSamplingWorkerOptions(
            num_workers=num_workers,
            worker_devices=[torch.device("cpu") for _ in range(num_workers)],
            worker_concurrency=worker_concurrency,
            # Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group
            # need to be connected. Thus, we need master ip address and master port to
            # initiate the connection.
            # Note that different groups of workers are independent, and thus
            # the sampling processes in different groups should be independent, and should
            # use different master ports.
            master_addr=master_ip_address,
            master_port=dist_sampling_port_for_current_rank,
            # Load testing show that when num_rpc_threads exceed 16, the performance
            # will degrade.
            num_rpc_threads=min(dataset.num_partitions, 16),
            rpc_timeout=600,
            channel_size=channel_size,
            pin_memory=device.type == "cuda",
        )
        # A dict-valued graph holds one entry per edge type (heterogeneous);
        # otherwise the dataset is homogeneous and edge_types stays None.
        if isinstance(dataset.graph, dict):
            edge_types = list(dataset.graph.keys())
        else:
            edge_types = None
        return (
            input_data,
            worker_options,
            DatasetSchema(
                is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type,
                edge_types=edge_types,
                node_feature_info=dataset.node_feature_info,
                edge_feature_info=dataset.edge_feature_info,
                edge_dir=dataset.edge_dir,
            ),
        )
def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
data = super()._collate_fn(msg)
data = set_missing_features(
data=data,
node_feature_info=self._node_feature_info,
edge_feature_info=self._edge_feature_info,
device=self.to_device,
)
if isinstance(data, HeteroData):
data = strip_label_edges(data)
if self._is_homogeneous_with_labeled_edge_type:
data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data)
return data