import sys
from collections import abc
from itertools import count
from typing import Callable, Optional, Tuple, Union
import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import (
MpDistSamplingWorkerOptions,
RemoteDistSamplingWorkerOptions,
)
from graphlearn_torch.sampler import NodeSamplerInput
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType
import gigl.distributed.utils
from gigl.common.logger import Logger
from gigl.distributed.base_dist_loader import BaseDistLoader
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.dist_ppr_sampler import (
PPR_EDGE_INDEX_METADATA_KEY,
PPR_WEIGHT_METADATA_KEY,
)
from gigl.distributed.dist_sampling_producer import DistSamplingProducer
from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler_options import (
PPRSamplerOptions,
SamplerOptions,
resolve_sampler_options,
)
from gigl.distributed.utils.neighborloader import (
DatasetSchema,
SamplingClusterSetup,
attach_ppr_outputs,
extract_edge_type_metadata,
extract_metadata,
labeled_to_homogeneous,
set_missing_features,
shard_nodes_by_process,
strip_label_edges,
strip_non_ppr_edge_types,
)
from gigl.src.common.types.graph_data import (
NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing
)
from gigl.types.graph import (
DEFAULT_HOMOGENEOUS_EDGE_TYPE,
DEFAULT_HOMOGENEOUS_NODE_TYPE,
)
# We don't see logs for graph store mode for whatever reason.
# TODO(#442): Revert this once the GCP issues are resolved.
def flush():
    """Force any buffered stdout/stderr output to be written out immediately.

    Used as a workaround so logs show up promptly in graph store mode (see the
    TODO(#442) note above this function).
    """
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
class DistNeighborLoader(BaseDistLoader):
    """Distributed neighbor loader supporting both "colocated" and "Graph Store" sampling setups.

    The sampling mode is inferred from the dataset type passed to ``__init__``:
    a ``RemoteDistDataset`` selects Graph Store mode (sampling happens on remote
    storage servers via RPC), while a ``DistDataset`` selects colocated mode
    (sampling workers are spawned locally next to the data).
    """

    # Counts instantiations of this class, per process.
    # This is needed so we can generate a unique worker key for each instance, for graph store mode.
    # NOTE: This is per-class, not per-instance.
    _counter = count(0)

    def __init__(
        self,
        dataset: Union[DistDataset, RemoteDistDataset],
        num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
        input_nodes: Optional[
            Union[
                torch.Tensor,
                Tuple[NodeType, torch.Tensor],
                abc.Mapping[int, torch.Tensor],
                Tuple[NodeType, abc.Mapping[int, torch.Tensor]],
            ]
        ] = None,
        num_workers: int = 1,
        batch_size: int = 1,
        context: Optional[DistributedContext] = None,  # TODO: (svij) Deprecate this
        local_process_rank: Optional[int] = None,  # TODO: (svij) Deprecate this
        local_process_world_size: Optional[int] = None,  # TODO: (svij) Deprecate this
        pin_memory_device: Optional[torch.device] = None,
        worker_concurrency: int = 4,
        channel_size: str = "4GB",
        prefetch_size: Optional[int] = None,
        process_start_gap_seconds: float = 60.0,
        max_concurrent_producer_inits: Optional[int] = None,
        num_cpu_threads: Optional[int] = None,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler_options: Optional[SamplerOptions] = None,
        non_blocking_transfers: bool = True,
    ):
        """
        Distributed Neighbor Loader.

        Takes in some input nodes and samples neighbors from the dataset.
        This loader should be used if you do not have any special sampling needs,
        e.g. you need to generate *training* examples for Anchor Based Link Prediction (ABLP) tasks.
        Though this loader is useful for generating random negative examples for ABLP training.

        Note: We try to adhere to the pyg dataloader api as much as possible.
        See the following for reference:
        https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/loader/node_loader.html#NodeLoader
        https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader

        Args:
            dataset (DistDataset | RemoteDistDataset): The dataset to sample from.
                If this is a `RemoteDistDataset`, then we are assumed to be in "Graph Store" mode.
            num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]):
                The number of neighbors to sample for each node in each iteration.
                If an entry is set to `-1`, all neighbors will be included.
                In heterogeneous graphs, may also take in a dictionary denoting
                the amount of neighbors to sample for each individual edge type.
                If ``KHopNeighborSamplerOptions`` is also provided, they must match.
            context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process.
            local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node.
            local_process_world_size (deprecated - will be removed soon) (int): Required if context provided. The total number of processes within a node.
            input_nodes (Tensor | Tuple[NodeType, Tensor] | dict[int, Tensor] | Tuple[NodeType, dict[int, Tensor]]):
                The nodes to start sampling from.
                It is of type `torch.LongTensor` for homogeneous graphs.
                If set to `None` for homogeneous settings, all nodes will be considered.
                In heterogeneous graphs, this flag must be passed in as a tuple that holds
                the node type and node indices. (default: `None`)
                For Graph Store mode, this must be a tuple of (NodeType, dict[int, Tensor]) or dict[int, Tensor].
                Where each Tensor in the dict is the node ids to sample from, by server.
                e.g. {0: [10, 20], 1: [30, 40]} means sample from nodes 10 and 20 on server 0, and nodes 30 and 40 on server 1.
                If a Graph Store input (e.g. ``dict[int, Tensor]``) is provided to colocated mode, or a colocated input (e.g. ``Tensor``) is provided to Graph Store mode,
                then an error will be raised.
            num_workers (int): How many workers to use (subprocesses to spawn) for
                distributed neighbor sampling of the current process. (default: ``1``).
            batch_size (int, optional): how many samples per batch to load
                (default: ``1``).
            pin_memory_device (str, optional): The target device that the sampled
                results should be copied to. If set to ``None``, the device is inferred
                via ``gigl.distributed.utils.device.get_available_device``, which uses the
                local_process_rank and torch.cuda.device_count() to assign the device. If cuda is not available,
                the cpu device will be used. (default: ``None``).
            worker_concurrency (int): The max sampling concurrency for each sampling
                worker. Load testing has shown that setting worker_concurrency to 4 yields the best performance
                for sampling. Although, you may wish to explore higher/lower settings when performance tuning.
                (default: `4`).
            channel_size (int or str): The shared-memory buffer size (bytes) allocated
                for the channel. Can be modified for performance tuning; a good starting point is: ``num_workers * 64MB``
                (default: "4GB").
            prefetch_size (Optional[int]): Max number of sampled messages to prefetch on the
                client side, per server. Only applies to Graph Store mode (remote workers).
                Lower values reduce server-side RPC thread contention when multiple loaders
                are active concurrently. (default: ``None``).
                Only applicable in Graph Store mode.
                If supplied and not in Graph Store mode, an error will be raised.
            process_start_gap_seconds (float): Delay between each process for initializing neighbor loader.
                In colocated mode, each process sleeps ``local_rank * process_start_gap_seconds``
                before initializing. In graph store mode, leader ranks are grouped into batches
                of ``max_concurrent_producer_inits`` and each batch sleeps
                ``batch_index * process_start_gap_seconds`` before dispatching RPCs.
            max_concurrent_producer_inits (int): Maximum number of leader ranks that may
                dispatch create-producer RPCs concurrently in graph store mode. Leaders are
                grouped into batches of this size; each batch is staggered by
                ``process_start_gap_seconds``. Only applies to graph store mode.
                Defaults to ``None`` (no staggering).
            num_cpu_threads (Optional[int]): Number of cpu threads PyTorch should use for CPU training/inference
                neighbor loading; on top of the per process parallelism.
                Defaults to `2` if set to `None` when using cpu training/inference.
            shuffle (bool): Whether to shuffle the input nodes. (default: ``False``).
            drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``).
            sampler_options (Optional[SamplerOptions]): Controls which sampler class is
                instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler,
                or ``CustomSamplerOptions`` to dynamically import a custom sampler class.
                If ``None``, defaults to ``KHopNeighborSamplerOptions(num_neighbors)``.
            non_blocking_transfers (bool): If True (default), batch-transfers all
                sampled tensors to the target CUDA device using non-blocking copies
                before collation, which can overlap data transfer with computation
                when source tensors reside in pinned memory. If False, the bulk
                transfer is skipped and GLT's default (blocking) device placement
                is used instead.
                See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html
                for background on pinned memory and non-blocking transfers.
        """
        # Set self._shutdowned right away, that way if we throw here, and __del__ is called,
        # then we can properly clean up and don't get extraneous error messages.
        self._shutdowned = True
        sampler_options = resolve_sampler_options(num_neighbors, sampler_options)

        # Resolve distributed context from the (deprecated) explicit arguments;
        # delete the originals so they cannot be used accidentally below.
        runtime = BaseDistLoader.resolve_runtime(
            context, local_process_rank, local_process_world_size
        )
        del context, local_process_rank, local_process_world_size

        # Determine mode: the dataset type is the single source of truth.
        if isinstance(dataset, RemoteDistDataset):
            self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE
        else:
            self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED
            # These knobs only make sense for Graph Store mode; reject them
            # eagerly so misconfiguration fails fast.
            if prefetch_size is not None:
                raise ValueError(
                    f"prefetch_size must be None when using Colocated mode, received {prefetch_size}"
                )
            if max_concurrent_producer_inits is not None:
                raise ValueError(
                    f"max_concurrent_producer_inits must be None when using Colocated mode, received {max_concurrent_producer_inits}"
                )
        # NOTE(review): `logger` is expected to be a module-level Logger instance
        # (e.g. `logger = Logger()`); its definition is not visible in this chunk — confirm.
        logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}")
        # Per-process instance index, used to build unique worker keys in
        # Graph Store mode (see _setup_for_graph_store).
        self._instance_count = next(self._counter)
        device = (
            pin_memory_device
            if pin_memory_device
            else gigl.distributed.utils.get_available_device(
                local_process_rank=runtime.local_rank
            )
        )

        # Mode-specific setup: both branches produce the sampler input,
        # worker options, and a DatasetSchema describing the graph.
        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
            assert isinstance(
                dataset, DistDataset
            ), "When using colocated mode, dataset must be a DistDataset."
            input_data, worker_options, dataset_schema = self._setup_for_colocated(
                input_nodes=input_nodes,
                dataset=dataset,
                local_rank=runtime.local_rank,
                local_world_size=runtime.local_world_size,
                device=device,
                master_ip_address=runtime.master_ip_address,
                node_rank=runtime.node_rank,
                node_world_size=runtime.node_world_size,
                num_workers=num_workers,
                worker_concurrency=worker_concurrency,
                channel_size=channel_size,
                num_cpu_threads=num_cpu_threads,
            )
        else:
            assert isinstance(
                dataset, RemoteDistDataset
            ), "When using Graph Store mode, dataset must be a RemoteDistDataset."
            if prefetch_size is None:
                logger.info(f"prefetch_size is not provided, using default of 4")
                prefetch_size = 4
            input_data, worker_options, dataset_schema = self._setup_for_graph_store(
                input_nodes=input_nodes,
                dataset=dataset,
                num_workers=num_workers,
                worker_concurrency=worker_concurrency,
                prefetch_size=prefetch_size,
                channel_size=channel_size,
            )

        # Cleanup temporary process group if needed: only tear down a group that
        # this constructor itself created (flagged by the runtime object).
        if (
            runtime.should_cleanup_distributed_context
            and torch.distributed.is_initialized()
        ):
            logger.info(
                f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
            )
            torch.distributed.destroy_process_group()

        # Create SamplingConfig (with patched fanout).
        sampling_config = BaseDistLoader.create_sampling_config(
            num_neighbors=num_neighbors,
            dataset_schema=dataset_schema,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )

        # Build the producer: a pre-constructed producer for colocated mode,
        # or an RPC callable for graph store mode.
        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
            assert isinstance(dataset, DistDataset)
            assert isinstance(worker_options, MpDistSamplingWorkerOptions)
            channel = BaseDistLoader.create_colocated_channel(worker_options)
            producer: Union[
                DistSamplingProducer, Callable[..., int]
            ] = DistSamplingProducer(
                data=dataset,
                sampler_input=input_data,
                sampling_config=sampling_config,
                worker_options=worker_options,
                channel=channel,
                sampler_options=sampler_options,
            )
        else:
            producer = GiglDistServer.create_sampling_producer

        # Call base class — handles metadata storage and connection initialization
        # (including staggered init for colocated mode).
        super().__init__(
            dataset=dataset,
            sampler_input=input_data,
            dataset_schema=dataset_schema,
            worker_options=worker_options,
            sampling_config=sampling_config,
            device=device,
            runtime=runtime,
            producer=producer,
            sampler_options=sampler_options,
            process_start_gap_seconds=process_start_gap_seconds,
            max_concurrent_producer_inits=max_concurrent_producer_inits,
            non_blocking_transfers=non_blocking_transfers,
        )

    def _setup_for_graph_store(
        self,
        input_nodes: Optional[
            Union[
                torch.Tensor,
                tuple[NodeType, torch.Tensor],
                abc.Mapping[int, torch.Tensor],
                tuple[NodeType, abc.Mapping[int, torch.Tensor]],
            ]
        ],
        dataset: RemoteDistDataset,
        num_workers: int,
        worker_concurrency: int,
        prefetch_size: int,
        channel_size: str,
    ) -> tuple[list[NodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema]:
        """Validate Graph Store inputs and build per-server sampler inputs.

        Args:
            input_nodes: Must be a mapping of ``server_rank -> node ids`` or a
                ``(node_type, mapping)`` tuple; anything else raises ``ValueError``.
            dataset: Remote dataset handle used to fetch schema info and cluster layout.
            num_workers / worker_concurrency / prefetch_size / channel_size:
                Forwarded into the remote worker options.

        Returns:
            A tuple of (one ``NodeSamplerInput`` per storage server, remote worker
            options, and the ``DatasetSchema`` describing the remote graph).

        Raises:
            ValueError: On malformed ``input_nodes`` or out-of-range server ranks.
        """
        if input_nodes is None:
            raise ValueError(
                f"When using Graph Store mode, input nodes must be provided, received {input_nodes}"
            )
        elif isinstance(input_nodes, torch.Tensor):
            raise ValueError(
                f"When using Graph Store mode, input nodes must be of type (abc.Mapping[int, torch.Tensor] | (NodeType, abc.Mapping[int, torch.Tensor]), received {type(input_nodes)}"
            )
        elif isinstance(input_nodes, tuple) and isinstance(
            input_nodes[1], torch.Tensor
        ):
            raise ValueError(
                f"When using Graph Store mode, input nodes must be of type (dict[int, torch.Tensor] | (NodeType, dict[int, torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
            )
        # Fetch schema information from the remote storage cluster.
        node_feature_info = dataset.fetch_node_feature_info()
        edge_feature_info = dataset.fetch_edge_feature_info()
        edge_types = dataset.fetch_edge_types()
        compute_rank = torch.distributed.get_rank()
        # Unique per (rank, loader-instance) so multiple loaders on the same
        # rank do not collide on the server side.
        worker_key = f"compute_rank_{compute_rank}_worker_{self._instance_count}"
        logger.info(f"Rank {compute_rank} worker key: {worker_key}")
        worker_options = BaseDistLoader.create_graph_store_worker_options(
            dataset=dataset,
            compute_rank=compute_rank,
            worker_key=worker_key,
            num_workers=num_workers,
            worker_concurrency=worker_concurrency,
            channel_size=channel_size,
            prefetch_size=prefetch_size,
        )
        logger.info(
            f"Rank {torch.distributed.get_rank()}! init for sampling rpc: "
            f"tcp://{worker_options.master_addr}:{worker_options.master_port}"
        )
        # Setup input data for the dataloader.
        # Determine nodes list and fallback input_type based on input_nodes structure.
        if isinstance(input_nodes, abc.Mapping):
            nodes = input_nodes
            fallback_input_type = None
            require_edge_feature_info = False
        elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping):
            nodes = input_nodes[1]
            fallback_input_type = input_nodes[0]
            require_edge_feature_info = True
        else:
            raise ValueError(
                f"When using Graph Store mode, input nodes must be of type (abc.Mapping[int, torch.Tensor] | (NodeType, abc.Mapping[int, torch.Tensor])), received {type(input_nodes)}"
            )
        # Determine input_type: a "labeled homogeneous" graph (containing the
        # default homogeneous edge type) forces the default homogeneous node type.
        if isinstance(edge_types, list):
            if DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_types:
                input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE
            else:
                input_type = fallback_input_type
        elif require_edge_feature_info:
            raise ValueError(
                "When using Graph Store mode, edge types must be provided for heterogeneous graphs."
            )
        else:
            input_type = None
        is_homogeneous_with_labeled_edge_type = (
            input_type == DEFAULT_HOMOGENEOUS_NODE_TYPE
        )
        # Convert from dict to list which is what the GLT DistNeighborLoader expects.
        servers = nodes.keys()
        if max(servers) >= dataset.cluster_info.num_storage_nodes or min(servers) < 0:
            raise ValueError(
                f"When using Graph Store mode, the server ranks must be in range [0, num_servers ({dataset.cluster_info.num_storage_nodes})), received inputs for servers: {list(nodes.keys())}"
            )
        # One NodeSamplerInput per storage server; servers with no requested
        # nodes get an empty tensor so list indices align with server ranks.
        input_data: list[NodeSamplerInput] = []
        for server_rank in range(dataset.cluster_info.num_storage_nodes):
            if server_rank in nodes:
                input_data.append(
                    NodeSamplerInput(node=nodes[server_rank], input_type=input_type)
                )
            else:
                input_data.append(
                    NodeSamplerInput(
                        node=torch.empty(0, dtype=torch.long), input_type=input_type
                    )
                )
        return (
            input_data,
            worker_options,
            DatasetSchema(
                is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type,
                edge_types=edge_types,
                node_feature_info=node_feature_info,
                edge_feature_info=edge_feature_info,
                edge_dir=dataset.fetch_edge_dir(),
            ),
        )

    def _setup_for_colocated(
        self,
        input_nodes: Optional[
            Union[
                torch.Tensor,
                Tuple[NodeType, torch.Tensor],
                abc.Mapping[int, torch.Tensor],
                Tuple[NodeType, abc.Mapping[int, torch.Tensor]],
            ]
        ],
        dataset: DistDataset,
        local_rank: int,
        local_world_size: int,
        device: torch.device,
        master_ip_address: str,
        node_rank: int,
        node_world_size: int,
        num_workers: int,
        worker_concurrency: int,
        channel_size: str,
        num_cpu_threads: Optional[int],
    ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]:
        """Validate colocated inputs, shard them per process, and build worker options.

        Args:
            input_nodes: A ``Tensor`` (homogeneous / labeled-homogeneous) or
                ``(node_type, Tensor)`` tuple (heterogeneous); mapping-style inputs
                are rejected as Graph-Store-only. ``None`` falls back to all
                ``dataset.node_ids`` (homogeneous datasets only).
            dataset: The local partitioned dataset.
            local_rank / local_world_size: Used to shard the input nodes across
                the processes on this machine.
            device / master_ip_address / node_rank / node_world_size / num_cpu_threads:
                Forwarded to the colocated sampling-worker initialization.
            num_workers / worker_concurrency / channel_size: Forwarded to worker options.

        Returns:
            A tuple of (this process's ``NodeSamplerInput``, multiprocessing worker
            options, and the ``DatasetSchema`` describing the local graph).

        Raises:
            ValueError: On malformed ``input_nodes`` or missing dataset node ids.
        """
        if input_nodes is None:
            if dataset.node_ids is None:
                raise ValueError(
                    "Dataset must have node ids if input_nodes are not provided."
                )
            if isinstance(dataset.node_ids, abc.Mapping):
                raise ValueError(
                    f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}"
                )
            input_nodes = dataset.node_ids
        if isinstance(input_nodes, abc.Mapping):
            raise ValueError(
                f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)}"
            )
        elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping):
            raise ValueError(
                f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
            )
        is_homogeneous_with_labeled_edge_type = False
        if isinstance(input_nodes, torch.Tensor):
            node_ids = input_nodes
            # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
            # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
            if isinstance(dataset.node_ids, abc.Mapping):
                if (
                    len(dataset.node_ids) == 1
                    and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
                ):
                    node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
                    is_homogeneous_with_labeled_edge_type = True
                else:
                    raise ValueError(
                        f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
                    )
            else:
                node_type = None
        else:
            node_type, node_ids = input_nodes
            assert isinstance(
                dataset.node_ids, abc.Mapping
            ), "Dataset must be heterogeneous if provided input nodes are a tuple."
        # Each process on this machine samples a disjoint shard of the inputs.
        curr_process_nodes = shard_nodes_by_process(
            input_nodes=node_ids,
            local_process_rank=local_rank,
            local_process_world_size=local_world_size,
        )
        input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)
        BaseDistLoader.initialize_colocated_sampling_worker(
            local_rank=local_rank,
            local_world_size=local_world_size,
            node_rank=node_rank,
            node_world_size=node_world_size,
            master_ip_address=master_ip_address,
            device=device,
            num_cpu_threads=num_cpu_threads,
        )
        # Sets up worker options for the dataloader: one sampling port per local
        # process, handed out by the master node.
        dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node(
            num_ports=local_world_size
        )
        dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank]
        worker_options = BaseDistLoader.create_colocated_worker_options(
            dataset_num_partitions=dataset.num_partitions,
            num_workers=num_workers,
            worker_concurrency=worker_concurrency,
            master_ip_address=master_ip_address,
            master_port=dist_sampling_port_for_current_rank,
            channel_size=channel_size,
            pin_memory=device.type == "cuda",
        )
        # A dict-of-graphs dataset is heterogeneous; its keys are the edge types.
        if isinstance(dataset.graph, dict):
            edge_types = list(dataset.graph.keys())
        else:
            edge_types = None
        return (
            input_data,
            worker_options,
            DatasetSchema(
                is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type,
                edge_types=edge_types,
                node_feature_info=dataset.node_feature_info,
                edge_feature_info=dataset.edge_feature_info,
                edge_dir=dataset.edge_dir,
            ),
        )

    def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
        """Collate a raw sample message into a PyG ``Data``/``HeteroData`` object.

        Post-processes the base collation by filling in missing features,
        stripping label edges, converting labeled-homogeneous output back to
        homogeneous form, attaching PPR outputs (when a PPR sampler is in use),
        and re-attaching user-defined metadata keys.
        """
        # Extract user-defined metadata before super()._collate_fn, which
        # calls GLT's to_hetero_data. to_hetero_data misinterprets #META. keys
        # as edge types and fails when edge_dir="out" (tries to call
        # reverse_edge_type on them). We strip them here and re-apply after.
        # TODO (mkolodner-sc): Remove once GLT's to_hetero_data is fixed.
        metadata, stripped_msg = extract_metadata(msg, self.to_device)
        data = super()._collate_fn(stripped_msg)
        data = set_missing_features(
            data=data,
            node_feature_info=self._node_feature_info,
            edge_feature_info=self._edge_feature_info,
            device=self.to_device,
        )
        if isinstance(data, HeteroData):
            data = strip_label_edges(data)
        if self._is_homogeneous_with_labeled_edge_type:
            data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data)
        if isinstance(self._sampler_options, PPRSamplerOptions):
            # Pull the PPR-specific keys out of the metadata dict; whatever
            # remains is user-defined and is re-attached below.
            matched, metadata = extract_edge_type_metadata(
                metadata=metadata,
                prefixes=[PPR_EDGE_INDEX_METADATA_KEY, PPR_WEIGHT_METADATA_KEY],
            )
            ppr_edge_indices = matched[PPR_EDGE_INDEX_METADATA_KEY]
            ppr_weights = matched[PPR_WEIGHT_METADATA_KEY]
            attach_ppr_outputs(data, ppr_edge_indices, ppr_weights)
            if isinstance(data, HeteroData):
                data = strip_non_ppr_edge_types(data, set(ppr_edge_indices.keys()))
        # Attach any remaining metadata (e.g. custom user-defined keys) directly onto the
        # data object so downstream code can access them via attribute lookup.
        for key, value in metadata.items():
            data[key] = value
        return data