Source code for gigl.distributed.distributed_neighborloader

from collections import Counter, abc
from typing import Dict, List, Optional, Tuple, Union

import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions
from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType

import gigl.distributed.utils
from gigl.common.logger import Logger
from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_link_prediction_dataset import DistLinkPredictionDataset
from gigl.distributed.utils.neighborloader import (
    labeled_to_homogeneous,
    patch_fanout_for_sampling,
    shard_nodes_by_process,
    strip_label_edges,
)
from gigl.src.common.types.graph_data import (
    NodeType,  # TODO (mkolodner-sc): Change to use torch_geometric.typing
)
from gigl.types.graph import (
    DEFAULT_HOMOGENEOUS_EDGE_TYPE,
    DEFAULT_HOMOGENEOUS_NODE_TYPE,
)

logger = Logger()
# When using CPU-based inference/training, we default to this many CPU threads for neighbor loading, on top of the per-process parallelism.
DEFAULT_NUM_CPU_THREADS = 2
class DistNeighborLoader(DistLoader):
    def __init__(
        self,
        dataset: DistLinkPredictionDataset,
        num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
        input_nodes: Optional[
            Union[torch.Tensor, Tuple[NodeType, torch.Tensor]]
        ] = None,
        num_workers: int = 1,
        batch_size: int = 1,
        context: Optional[DistributedContext] = None,  # TODO: (svij) Deprecate this
        local_process_rank: Optional[int] = None,  # TODO: (svij) Deprecate this
        local_process_world_size: Optional[int] = None,  # TODO: (svij) Deprecate this
        pin_memory_device: Optional[torch.device] = None,
        worker_concurrency: int = 4,
        channel_size: str = "4GB",
        process_start_gap_seconds: float = 60.0,
        num_cpu_threads: Optional[int] = None,
        shuffle: bool = False,
        drop_last: bool = False,
    ):
        """
        Note: We try to adhere to the PyG dataloader API as much as possible. See the following for reference:
        https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/loader/node_loader.html#NodeLoader
        https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader

        Args:
            dataset (DistLinkPredictionDataset): The dataset to sample from.
            num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The number of neighbors to sample
                for each node in each iteration. If an entry is set to `-1`, all neighbors will be included.
                In heterogeneous graphs, may also be a dictionary denoting the number of neighbors to sample
                for each individual edge type.
            context (deprecated - will be removed soon) (DistributedContext): Distributed context information of
                the current process.
            local_process_rank (deprecated - will be removed soon) (int): Required if context is provided.
                The local rank of the current process within a node.
            local_process_world_size (deprecated - will be removed soon) (int): Required if context is provided.
                The total number of processes within a node.
            input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The indices of seed nodes to start sampling from.
                It is of type `torch.LongTensor` for homogeneous graphs. If set to `None` for homogeneous settings,
                all nodes will be considered. In heterogeneous graphs, this flag must be passed in as a tuple that
                holds the node type and node indices. (default: `None`)
            num_workers (int): How many workers to use (subprocesses to spawn) for distributed neighbor sampling
                of the current process. (default: ``1``).
            batch_size (int, optional): How many samples per batch to load. (default: ``1``).
            pin_memory_device (torch.device, optional): The target device that the sampled results should be copied to.
                If set to ``None``, the device is inferred via ``gigl.distributed.utils.device.get_available_device``,
                which uses the local_process_rank and torch.cuda.device_count() to assign the device. If CUDA is not
                available, the CPU device will be used. (default: ``None``).
            worker_concurrency (int): The maximum sampling concurrency for each sampling worker. Load testing has
                shown that setting worker_concurrency to 4 yields the best sampling performance, although you may
                wish to explore higher/lower settings when performance tuning. (default: `4`).
            channel_size (int or str): The shared-memory buffer size (bytes) allocated for the channel. Can be
                modified for performance tuning; a good starting point is ``num_workers * 64MB``. (default: "4GB").
            process_start_gap_seconds (float): Delay (in seconds) between each process when initializing the
                neighbor loader. At large scales, it is recommended to set this value to between 60 and 120 seconds,
                otherwise multiple processes may attempt to initialize dataloaders at overlapping times, which can
                cause CPU memory OOM.
            num_cpu_threads (Optional[int]): Number of CPU threads PyTorch should use for CPU training/inference
                neighbor loading, on top of the per-process parallelism. Defaults to `2` if set to `None` when
                using CPU training/inference.
            shuffle (bool): Whether to shuffle the input nodes. (default: ``False``).
            drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``).
        """
        # Set self._shutdowned right away; that way, if we throw here and __del__ is called,
        # we can properly clean up and don't get extraneous error messages.
        # We set it to `True` as we don't need to clean up right away, and it will get set
        # to `False` in `super().__init__()`, e.g.
        # https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_loader.py#L125C1-L126C1
        self._shutdowned = True

        node_world_size: int
        node_rank: int
        rank: int
        world_size: int
        local_rank: int
        local_world_size: int
        master_ip_address: str
        should_cleanup_distributed_context: bool = False

        if context:
            assert (
                local_process_world_size is not None
            ), "context: DistributedContext provided, so local_process_world_size must be provided."
            assert (
                local_process_rank is not None
            ), "context: DistributedContext provided, so local_process_rank must be provided."
            master_ip_address = context.main_worker_ip_address
            node_world_size = context.global_world_size
            node_rank = context.global_rank
            local_world_size = local_process_world_size
            local_rank = local_process_rank
            rank = node_rank * local_world_size + local_rank
            world_size = node_world_size * local_world_size
            if not torch.distributed.is_initialized():
                logger.info(
                    "Process group is not available; trying torch.distributed.init_process_group to communicate necessary setup information."
                )
                should_cleanup_distributed_context = True
                logger.info(
                    f"Initializing process group with master ip address: {master_ip_address}, rank: {rank}, world size: {world_size}, local_rank: {local_rank}, local_world_size: {local_world_size}."
                )
                torch.distributed.init_process_group(
                    backend="gloo",  # We just default to gloo for this temporary process group
                    init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}",
                    rank=rank,
                    world_size=world_size,
                )
        else:
            assert (
                torch.distributed.is_initialized()
            ), f"context: DistributedContext is None, so the process group must be initialized before constructing this object {self.__class__.__name__}."
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks()
            master_ip_address = rank_ip_addresses[0]
            count_ranks_per_ip_address = Counter(rank_ip_addresses)
            local_world_size = count_ranks_per_ip_address[master_ip_address]
            for rank_ip_address, count in count_ranks_per_ip_address.items():
                if count != local_world_size:
                    raise ValueError(
                        f"All ranks must have the same number of processes, but found {count} processes for rank {rank} on ip {rank_ip_address}, expected {local_world_size}. "
                        + f"count_ranks_per_ip_address = {count_ranks_per_ip_address}"
                    )
            node_world_size = len(count_ranks_per_ip_address)
            local_rank = rank % local_world_size
            node_rank = rank // local_world_size

        del (
            context,
            local_process_rank,
            local_process_world_size,
        )  # Delete deprecated vars so we don't accidentally use them.
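
        # Illustrative note (added comment; figures are hypothetical): with 2 nodes and
        # 4 processes per node (local_world_size == 4), the global rank layout follows
        #   rank = node_rank * local_world_size + local_rank
        # so node_rank 0 owns ranks 0..3 and node_rank 1 owns ranks 4..7, regardless of
        # whether the values came from the deprecated `context` path or from torch.distributed.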

        device = (
            pin_memory_device
            if pin_memory_device
            else gigl.distributed.utils.get_available_device(
                local_process_rank=local_rank
            )
        )

        logger.info(
            f"Dataset Building started on {node_rank} of {node_world_size} nodes, using the following node as main: {master_ip_address}"
        )

        if input_nodes is None:
            if dataset.node_ids is None:
                raise ValueError(
                    "Dataset must have node ids if input_nodes are not provided."
                )
            if isinstance(dataset.node_ids, abc.Mapping):
                raise ValueError(
                    f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}"
                )
            input_nodes = dataset.node_ids

        if isinstance(num_neighbors, abc.Mapping):
            # TODO(kmonte): We should enable this. We have two blockers:
            # 1. We need to treat `EdgeType` as a proper tuple, not the GiGL `EdgeType`.
            # 2. There are (likely) some GLT bugs around https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_neighbor_sampler.py#L317-L318
            #    where if num_neighbors is a dict then we index into it improperly.
            if not isinstance(dataset.graph, abc.Mapping):
                raise ValueError(
                    "When num_neighbors is a dict, the dataset must be heterogeneous."
                )
            if num_neighbors.keys() != dataset.graph.keys():
                raise ValueError(
                    f"num_neighbors must have all edge types in the graph, received: {num_neighbors.keys()} for graph with edge types {dataset.graph.keys()}"
                )
            hops = len(next(iter(num_neighbors.values())))
            if not all(len(fanout) == hops for fanout in num_neighbors.values()):
                raise ValueError(
                    f"num_neighbors must be a dict of edge types with the same number of hops. Received: {num_neighbors}"
                )

        # Determines if the node ids passed in are heterogeneous or homogeneous.
        self._is_labeled_heterogeneous = False

        if isinstance(input_nodes, torch.Tensor):
            node_ids = input_nodes
            # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting;
            # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
            if isinstance(dataset.node_ids, abc.Mapping):
                if (
                    len(dataset.node_ids) == 1
                    and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
                ):
                    node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
                    self._is_labeled_heterogeneous = True
                    num_neighbors = patch_fanout_for_sampling(
                        dataset.get_edge_types(), num_neighbors
                    )
                else:
                    raise ValueError(
                        f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR, if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
                    )
            else:
                node_type = None
        else:
            node_type, node_ids = input_nodes
            assert isinstance(
                dataset.node_ids, abc.Mapping
            ), "Dataset must be heterogeneous if provided input nodes are a tuple."
            num_neighbors = patch_fanout_for_sampling(
                dataset.get_edge_types(), num_neighbors
            )

        curr_process_nodes = shard_nodes_by_process(
            input_nodes=node_ids,
            local_process_rank=local_rank,
            local_process_world_size=local_world_size,
        )

        input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)

        # Sets up processes and the torch device for initializing the GLT DistNeighborLoader,
        # setting up RPC and worker groups to minimize memory overhead and CPU contention.
        logger.info(
            f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} using device: {device}"
        )

        should_use_cpu_workers = device.type == "cpu"
        if should_use_cpu_workers and num_cpu_threads is None:
            logger.info(
                "Using CPU workers, but found num_cpu_threads to be None. "
                f"Defaulting num_cpu_threads to {DEFAULT_NUM_CPU_THREADS}."
            )
            num_cpu_threads = DEFAULT_NUM_CPU_THREADS

        neighbor_loader_ports = gigl.distributed.utils.get_free_ports_from_master_node(
            num_ports=local_world_size
        )
        neighbor_loader_port_for_current_rank = neighbor_loader_ports[local_rank]

        gigl.distributed.utils.init_neighbor_loader_worker(
            master_ip_address=master_ip_address,
            local_process_rank=local_rank,
            local_process_world_size=local_world_size,
            rank=node_rank,
            world_size=node_world_size,
            master_worker_port=neighbor_loader_port_for_current_rank,
            device=device,
            should_use_cpu_workers=should_use_cpu_workers,
            # Lever to explore tuning for CPU-based inference
            num_cpu_threads=num_cpu_threads,
            process_start_gap_seconds=process_start_gap_seconds,
        )
        logger.info(
            f"Finished initializing neighbor loader worker: {local_rank}/{local_world_size}"
        )

        # Sets up worker options for the dataloader
        dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node(
            num_ports=local_world_size
        )
        dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank]

        worker_options = MpDistSamplingWorkerOptions(
            num_workers=num_workers,
            worker_devices=[torch.device("cpu") for _ in range(num_workers)],
            worker_concurrency=worker_concurrency,
            # Each worker will spawn several sampling workers, and all sampling workers
            # spawned by workers in one group need to be connected. Thus, we need the
            # master ip address and master port to initiate the connection.
            # Note that different groups of workers are independent, so the sampling
            # processes in different groups should be independent and should use
            # different master ports.
            master_addr=master_ip_address,
            master_port=dist_sampling_port_for_current_rank,
            # Load testing shows that when num_rpc_threads exceeds 16, performance
            # will degrade.
            num_rpc_threads=min(dataset.num_partitions, 16),
            rpc_timeout=600,
            channel_size=channel_size,
            pin_memory=device.type == "cuda",
        )

        sampling_config = SamplingConfig(
            sampling_type=SamplingType.NODE,
            num_neighbors=num_neighbors,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            with_edge=True,
            collect_features=True,
            with_neg=False,
            with_weight=False,
            edge_dir=dataset.edge_dir,
            seed=None,  # It's actually optional - None means random.
        )

        if should_cleanup_distributed_context and torch.distributed.is_initialized():
            logger.info(
                f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
            )
            torch.distributed.destroy_process_group()

        super().__init__(dataset, input_data, sampling_config, device, worker_options)

    def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
        data = super()._collate_fn(msg)
        if isinstance(data, HeteroData):
            data = strip_label_edges(data)
        if self._is_labeled_heterogeneous:
            data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data)
        return data
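
# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for reference; not executed by this module).
# Assumes `dataset` is an already-built homogeneous DistLinkPredictionDataset
# and that the caller has initialized torch.distributed; the fanout, batch
# size, and variable names below are hypothetical.
#
#     loader = DistNeighborLoader(
#         dataset=dataset,
#         num_neighbors=[10, 10],  # 2-hop fanout: up to 10 neighbors per hop
#         batch_size=512,
#         shuffle=True,
#     )
#     for batch in loader:  # each batch is a torch_geometric.data.Data subgraph
#         ...
# ---------------------------------------------------------------------------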