Source code for gigl.distributed.dist_ablp_neighborloader

from collections import abc, defaultdict
from itertools import count
from typing import Callable, Optional, Union

import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import (
    MpDistSamplingWorkerOptions,
    RemoteDistSamplingWorkerOptions,
)
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType

import gigl.distributed.utils
from gigl.common.logger import Logger
from gigl.distributed.base_dist_loader import BaseDistLoader
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.dist_ppr_sampler import (
    PPR_EDGE_INDEX_METADATA_KEY,
    PPR_WEIGHT_METADATA_KEY,
)
from gigl.distributed.dist_sampling_producer import DistSamplingProducer
from gigl.distributed.graph_store.dist_server import DistServer
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler import (
    NEGATIVE_LABEL_METADATA_KEY,
    POSITIVE_LABEL_METADATA_KEY,
    ABLPNodeSamplerInput,
)
from gigl.distributed.sampler_options import (
    PPRSamplerOptions,
    SamplerOptions,
    resolve_sampler_options,
)
from gigl.distributed.utils.neighborloader import (
    DatasetSchema,
    SamplingClusterSetup,
    attach_ppr_outputs,
    extract_edge_type_metadata,
    extract_metadata,
    labeled_to_homogeneous,
    set_missing_features,
    shard_nodes_by_process,
    strip_label_edges,
    strip_non_ppr_edge_types,
)
from gigl.src.common.types.graph_data import (
    NodeType,  # TODO (mkolodner-sc): Change to use torch_geometric.typing
)
from gigl.types.graph import (
    DEFAULT_HOMOGENEOUS_EDGE_TYPE,
    DEFAULT_HOMOGENEOUS_NODE_TYPE,
    label_edge_type_to_message_passing_edge_type,
    message_passing_to_negative_label,
    message_passing_to_positive_label,
    reverse_edge_type,
    select_label_edge_types,
)
from gigl.utils.data_splitters import get_labels_for_anchor_nodes
from gigl.utils.sampling import ABLPInputNodes

[docs] logger = Logger()
[docs] class DistABLPLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. # NOTE: This is per-class, not per-instance. _counter = count(0) def __init__( self, dataset: Union[DistDataset, RemoteDistDataset], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ Union[ torch.Tensor, tuple[NodeType, torch.Tensor], # Graph Store mode inputs dict[int, ABLPInputNodes], ] ] = None, supervision_edge_type: Optional[Union[EdgeType, list[EdgeType]]] = None, num_workers: int = 1, batch_size: int = 1, pin_memory_device: Optional[torch.device] = None, worker_concurrency: int = 4, prefetch_size: Optional[int] = None, channel_size: str = "4GB", process_start_gap_seconds: float = 60.0, max_concurrent_producer_inits: Optional[int] = None, num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, sampler_options: Optional[SamplerOptions] = None, context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate this local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this non_blocking_transfers: bool = True, ): """ Neighbor loader for Anchor Based Link Prediction (ABLP) tasks. Note that for this class, the dataset must *always* be heterogeneous, as we need separate edge types for positive and negative labels. By default, the loader will return {py:class} `torch_geometric.data.HeteroData` (heterogeneous) objects, but will return a {py:class}`torch_geometric.data.Data` (homogeneous) object if the dataset is "labeled homogeneous". The following fields may also be present: - `y_positive`: `dict[int, torch.Tensor]` mapping from local anchor node id to a tensor of positive label node ids. - `y_negative`: (Optional) `dict[int, torch.Tensor]` mapping from local anchor node id to a tensor of negative label node ids. This will only be present if the supervision edge type has negative labels. NOTE: for both y_positive, and y_negative, the values represented in both the key and value of the dicts are the *local* node ids of the sampled nodes, not the global node ids. In order to get the global node ids, you can use the `node` field of the Data/HeteroData object. e.g. global_positive_node_id_labels = data.node[data.y_positive[local_anchor_node_id]]. The underlying graph engine may also add the following fields to the output Data object: - num_sampled_nodes: If heterogeneous. a dictionary mapping from node type to the number of sampled nodes for that type, by hop. if homogeneous, a tensor the number of sampled nodes, by hop. - num_sampled_edges: If heterogeneous, a dictionary mapping from edge type to the number of sampled edges for that type, by hop. If homogeneous, a tensor denoting the number of sampled edges, by hop. Let's use the following homogeneous graph (https://is.gd/a8DK15) as an example: 0 -> 1 [label="Positive example" color="green"] 0 -> 2 [label="Negative example" color="red"] 0 -> {3, 4} 3 -> {5, 6} 4 -> {7, 8} 1 -> 9 # shouldn't be sampled 2 -> 10 # shouldn't be sampled For sampling around node `0`, the fields on the output Data object will be: - `y_positive`: {0: torch.tensor([1])} # 1 is the only positive label for node 0 - `y_negative`: {0: torch.tensor([2])} # 2 is the only negative label for node 0 NOTE: both label fields will instead be `dict[EdgeType, dict[int, torch.Tensor]]` if multiple supervision edge types are provided. e.g. if there are supervision edge types: (a, to, b) and (a, to, c), then the label fields could be: - `y_positive`: {(a, to, b): {0: torch.tensor([1])}, (a, to, c): {0: torch.tensor([2])}} - `y_negative`: {(a, to, b): {0: torch.tensor([3])}, (a, to, c): {0: torch.tensor([4])}} Args: dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. If this is a `RemoteDistDataset`, then we are in "Graph Store" mode. num_neighbors (list[int] or dict[tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. If ``KHopNeighborSamplerOptions`` is also provided, they must match. input_nodes: Indices of seed nodes to start sampling from. For Colocated mode: `torch.Tensor` or `tuple[NodeType, torch.Tensor]`. If set to `None` for homogeneous settings, all nodes will be considered. In heterogeneous graphs, this flag must be passed in as a tuple that holds the node type and node indices. NOTE: We intend to migrate colocated mode to have a similar input format to Graph Store mode in the future. We want to do this so that users can easily control labels per anchor. For Graph Store mode: `dict[int, ABLPInputNodes]` Maps server_rank to an ABLPInputNodes dataclass containing anchor nodes, positive labels, and negative labels with explicit node type and edge type info. This is the return type of `RemoteDistDataset.fetch_ablp_input()`. supervision_edge_type (Optional[Union[EdgeType, list[EdgeType]]]): The edge type(s) to use for supervision. For Colocated mode: Must be None iff the dataset is labeled homogeneous. If set to a single EdgeType, the positive and negative labels will be stored in the `y_positive` and `y_negative` fields of the Data object. If set to a list of EdgeTypes, the positive and negative labels will be stored in the `y_positive` and `y_negative` fields of the Data object, with the key being the EdgeType. (default: `None`) For Graph Store mode: Must not be provided (must be None). The supervision edge types are inferred from the label edge type keys in ABLPInputNodes. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). batch_size (int, optional): how many samples per batch to load (default: ``1``). pin_memory_device (str, optional): The target device that the sampled results should be copied to. If set to ``None``, the device is inferred based off of (got by ``gigl.distributed.utils.device.get_available_device``). Which uses the local_process_rank and torch.cuda.device_count() to assign the device. If cuda is not available, the cpu device will be used. (default: ``None``). worker_concurrency (int): The max sampling concurrency for each sampling worker. Load testing has showed that setting worker_concurrency to 4 yields the best performance for sampling. Although, you may whish to explore higher/lower settings when performance tuning. (default: `4`). prefetch_size (Optional[int]): Max number of sampled messages to prefetch on the client side, per server. Only applies to Graph Store mode (remote workers). Lower values reduce server-side RPC thread contention when multiple loaders are active concurrently. (default: ``None``). If supplied and not it Graph Store mode, an error will be raised. channel_size (int or str): The shared-memory buffer size (bytes) allocated for the channel. Can be modified for performance tuning; a good starting point is: ``num_workers * 64MB`` (default: "4GB"). process_start_gap_seconds (float): Delay between each process for initializing neighbor loader. In colocated mode, each process sleeps ``local_rank * process_start_gap_seconds`` before initializing. In graph store mode, leader ranks are grouped into batches of ``max_concurrent_producer_inits`` and each batch sleeps ``batch_index * process_start_gap_seconds`` before dispatching RPCs. max_concurrent_producer_inits (int): Maximum number of leader ranks that may dispatch create-producer RPCs concurrently in graph store mode. Leaders are grouped into batches of this size; each batch is staggered by ``process_start_gap_seconds``. Only applies to graph store mode. Defaults to ``None`` (no staggering). num_cpu_threads (Optional[int]): Number of cpu threads PyTorch should use for CPU training/inference neighbor loading; on top of the per process parallelism. Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Defaults to `KHopNeighborSamplerOptions`, which will use the num_neighbors argument to instantiate the sampler. context (deprecated - will be removed soon) (Optional[DistributedContext]): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon) (int): The total number of processes within a node. non_blocking_transfers (bool): If True (default), batch-transfers all sampled tensors to the target CUDA device using non-blocking copies before collation, which can overlap data transfer with computation when source tensors reside in pinned memory. If False, the bulk transfer is skipped and GLT's default (blocking) device placement is used instead. See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html for background on pinned memory and non-blocking transfers. """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, # then we can properly clean up and don't get extraneous error messages. self._shutdowned = True sampler_options = resolve_sampler_options(num_neighbors, sampler_options) # Determine sampling cluster setup based on dataset type if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE if supervision_edge_type is not None: raise ValueError( "supervision_edge_type must not be provided when using Graph Store mode. " "The supervision edge types are inferred from the ABLPInputNodes label keys in input_nodes." ) # self._supervision_edge_types will be set in _setup_for_graph_store else: self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED if supervision_edge_type is None: self._supervision_edge_types: list[EdgeType] = [ DEFAULT_HOMOGENEOUS_EDGE_TYPE ] elif isinstance(supervision_edge_type, list): if not supervision_edge_type: raise ValueError( "supervision_edge_type must be a non-empty list when providing multiple supervision edge types." ) self._supervision_edge_types = supervision_edge_type else: self._supervision_edge_types = [supervision_edge_type] if prefetch_size is not None: raise ValueError( f"prefetch_size must be None when using Colocated mode, received {prefetch_size}" ) if max_concurrent_producer_inits is not None: raise ValueError( f"max_concurrent_producer_inits must be None when using Colocated mode, received {max_concurrent_producer_inits}" ) logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") del supervision_edge_type self._instance_count = next(self._counter) # Resolve distributed context runtime = BaseDistLoader.resolve_runtime( context, local_process_rank, local_process_world_size ) del context, local_process_rank, local_process_world_size device = ( pin_memory_device if pin_memory_device else gigl.distributed.utils.get_available_device( local_process_rank=runtime.local_rank ) )
[docs] self.to_device = device
# Mode-specific setup if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance( dataset, DistDataset ), "When using colocated mode, dataset must be a DistDataset." # Validate input_nodes type for colocated mode if isinstance(input_nodes, dict): raise ValueError( f"When using Colocated mode, input_nodes must be of type " f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), " f"received {type(input_nodes)}" ) setup_info = self._setup_for_colocated( input_nodes=input_nodes, dataset=dataset, local_rank=runtime.local_rank, local_world_size=runtime.local_world_size, device=device, master_ip_address=runtime.master_ip_address, node_rank=runtime.node_rank, node_world_size=runtime.node_world_size, num_workers=num_workers, worker_concurrency=worker_concurrency, channel_size=channel_size, num_cpu_threads=num_cpu_threads, ) sampler_input: Union[ ABLPNodeSamplerInput, list[ABLPNodeSamplerInput] ] = setup_info[0] worker_options: Union[ MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions ] = setup_info[1] dataset_schema: DatasetSchema = setup_info[2] else: # Graph Store mode assert isinstance( dataset, RemoteDistDataset ), "When using Graph Store mode, dataset must be a RemoteDistDataset." # Validate input_nodes type for Graph Store mode if not isinstance(input_nodes, dict): raise ValueError( f"When using Graph Store mode, input_nodes must be of type " f"dict[int, ABLPInputNodes], " f"received {type(input_nodes)}" ) if prefetch_size is None: logger.info(f"prefetch_size is not provided, using default of 4") prefetch_size = 4 ( sampler_input, worker_options, dataset_schema, ) = self._setup_for_graph_store( input_nodes=input_nodes, dataset=dataset, num_workers=num_workers, worker_concurrency=worker_concurrency, channel_size=channel_size, prefetch_size=prefetch_size, ) # Cleanup temporary process group if needed if ( runtime.should_cleanup_distributed_context and torch.distributed.is_initialized() ): logger.info( f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." ) torch.distributed.destroy_process_group() # Create SamplingConfig (with patched fanout) sampling_config = BaseDistLoader.create_sampling_config( num_neighbors=num_neighbors, dataset_schema=dataset_schema, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, ) # Build the producer: a pre-constructed producer for colocated mode, # or an RPC callable for graph store mode. if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) channel = BaseDistLoader.create_colocated_channel(worker_options) producer: Union[ DistSamplingProducer, Callable[..., int] ] = DistSamplingProducer( data=dataset, sampler_input=sampler_input, sampling_config=sampling_config, worker_options=worker_options, channel=channel, sampler_options=sampler_options, ) else: producer = DistServer.create_sampling_producer # Call base class — handles metadata storage and connection initialization # (including staggered init for colocated mode). super().__init__( dataset=dataset, sampler_input=sampler_input, dataset_schema=dataset_schema, worker_options=worker_options, sampling_config=sampling_config, device=device, runtime=runtime, producer=producer, sampler_options=sampler_options, process_start_gap_seconds=process_start_gap_seconds, max_concurrent_producer_inits=max_concurrent_producer_inits, non_blocking_transfers=non_blocking_transfers, ) def _setup_for_colocated( self, input_nodes: Optional[ Union[ torch.Tensor, tuple[NodeType, torch.Tensor], ] ], dataset: DistDataset, local_rank: int, local_world_size: int, device: torch.device, master_ip_address: str, node_rank: int, node_world_size: int, num_workers: int, worker_concurrency: int, channel_size: str, num_cpu_threads: Optional[int], ) -> tuple[ABLPNodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]: """ Setup method for colocated (non-Graph Store) mode. Args: input_nodes: Input nodes for sampling (tensor or tuple of node type and tensor). dataset: The DistDataset to sample from. local_rank: Local rank of the current process. local_world_size: Total number of processes on this machine. device: Target device for sampled data. master_ip_address: IP address of the master node. node_rank: Rank of the current machine. node_world_size: Total number of machines. num_workers: Number of sampling workers. worker_concurrency: Max sampling concurrency per worker. channel_size: Size of shared memory channel. num_cpu_threads: Number of CPU threads for PyTorch. Returns: Tuple of (ABLPNodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema). """ # Validate input format - should not be Graph Store format if isinstance(input_nodes, abc.Mapping): raise ValueError( f"When using Colocated mode, input_nodes must be of type (torch.Tensor | tuple[NodeType, torch.Tensor]), " f"received {type(input_nodes)}" ) elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping): raise ValueError( f"When using Colocated mode, input_nodes must be of type (torch.Tensor | tuple[NodeType, torch.Tensor]), " f"received tuple with second element of type {type(input_nodes[1])}" ) if not isinstance(dataset.graph, abc.Mapping): raise ValueError( f"The dataset must be heterogeneous for ABLP. Received dataset with graph of type: {type(dataset.graph)}" ) is_homogeneous_with_labeled_edge_type: bool = True if isinstance(input_nodes, tuple): if self._supervision_edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: raise ValueError( "When using heterogeneous ABLP, you must provide supervision_edge_types." ) is_homogeneous_with_labeled_edge_type = False anchor_node_type, anchor_node_ids = input_nodes # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if # this assumption is no longer valid and/or is too opinionated for supervision_edge_type in self._supervision_edge_types: assert ( supervision_edge_type[0] == anchor_node_type ), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \ got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}" if dataset.edge_dir == "in": self._supervision_edge_types = [ reverse_edge_type(supervision_edge_type) for supervision_edge_type in self._supervision_edge_types ] elif isinstance(input_nodes, torch.Tensor): if self._supervision_edge_types != [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: raise ValueError( f"Expected supervision edge type to be None for homogeneous input nodes, got {self._supervision_edge_types}" ) anchor_node_ids = input_nodes anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE elif input_nodes is None: if dataset.node_ids is None: raise ValueError( "Dataset must have node ids if input_nodes are not provided." ) if isinstance(dataset.node_ids, abc.Mapping): raise ValueError( f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}" ) if self._supervision_edge_types != [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: raise ValueError( f"Expected supervision edge type to be None for homogeneous input nodes, got {self._supervision_edge_types}" ) anchor_node_ids = dataset.node_ids anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE else: raise ValueError(f"Unexpected input_nodes type: {type(input_nodes)}") missing_edge_types = set(self._supervision_edge_types) - set( dataset.graph.keys() ) if missing_edge_types: raise ValueError( f"Missing edge types in dataset: {missing_edge_types}. Edge types in dataset: {dataset.graph.keys()}" ) # Type narrowing - anchor_node_ids is always a Tensor in colocated mode assert isinstance(anchor_node_ids, torch.Tensor) if len(anchor_node_ids.shape) != 1: raise ValueError( f"input_nodes must be a 1D tensor, got {anchor_node_ids.shape}." ) curr_process_nodes = shard_nodes_by_process( input_nodes=anchor_node_ids, local_process_rank=local_rank, local_process_world_size=local_world_size, ) self._positive_label_edge_types: list[EdgeType] = [] self._negative_label_edge_types: list[EdgeType] = [] positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} for supervision_edge_type in self._supervision_edge_types: ( positive_label_edge_type, negative_label_edge_type, ) = select_label_edge_types(supervision_edge_type, dataset.graph.keys()) self._positive_label_edge_types.append(positive_label_edge_type) if negative_label_edge_type is not None: self._negative_label_edge_types.append(negative_label_edge_type) positive_labels, negative_labels = get_labels_for_anchor_nodes( dataset=dataset, node_ids=curr_process_nodes, positive_label_edge_type=positive_label_edge_type, negative_label_edge_type=negative_label_edge_type, ) positive_labels_by_label_edge_type[ positive_label_edge_type ] = positive_labels if negative_label_edge_type is not None and negative_labels is not None: negative_labels_by_label_edge_type[ negative_label_edge_type ] = negative_labels sampler_input = ABLPNodeSamplerInput( node=curr_process_nodes, input_type=anchor_node_type, positive_label_by_edge_types=positive_labels_by_label_edge_type, negative_label_by_edge_types=negative_labels_by_label_edge_type, ) BaseDistLoader.initialize_colocated_sampling_worker( local_rank=local_rank, local_world_size=local_world_size, node_rank=node_rank, node_world_size=node_world_size, master_ip_address=master_ip_address, device=device, num_cpu_threads=num_cpu_threads, ) # Sets up worker options for the dataloader dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node( num_ports=local_world_size ) dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank] worker_options = BaseDistLoader.create_colocated_worker_options( dataset_num_partitions=dataset.num_partitions, num_workers=num_workers, worker_concurrency=worker_concurrency, master_ip_address=master_ip_address, master_port=dist_sampling_port_for_current_rank, channel_size=channel_size, pin_memory=device.type == "cuda", ) edge_types = list(dataset.graph.keys()) return ( sampler_input, worker_options, DatasetSchema( is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type, edge_types=edge_types, node_feature_info=dataset.node_feature_info, edge_feature_info=dataset.edge_feature_info, edge_dir=dataset.edge_dir, ), ) def _setup_for_graph_store( self, input_nodes: dict[int, ABLPInputNodes], dataset: RemoteDistDataset, num_workers: int, worker_concurrency: int, channel_size: str, prefetch_size: int, ) -> tuple[ list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema ]: """ Setup method for Graph Store mode. Args: input_nodes: ABLP input from RemoteDistDataset.fetch_ablp_input(). Maps server_rank to ABLPInputNodes containing anchor nodes, positive/negative labels with explicit node type and edge type information. dataset: The RemoteDistDataset to sample from. num_workers: Number of sampling workers. worker_concurrency: Max sampling concurrency per worker. channel_size: Size of the remote shared-memory buffer. prefetch_size: Max prefetched sampled messages per server on client side. Returns: Tuple of (list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema). """ node_feature_info = dataset.fetch_node_feature_info() edge_feature_info = dataset.fetch_edge_feature_info() edge_types = dataset.fetch_edge_types() compute_rank = torch.distributed.get_rank() worker_key = ( f"compute_ablp_loader_rank_{compute_rank}_worker_{self._instance_count}" ) logger.info(f"rank: {compute_rank}, worker_key: {worker_key}") worker_options = BaseDistLoader.create_graph_store_worker_options( dataset=dataset, compute_rank=compute_rank, worker_key=worker_key, num_workers=num_workers, worker_concurrency=worker_concurrency, channel_size=channel_size, prefetch_size=prefetch_size, ) logger.info( f"Rank {torch.distributed.get_rank()}! init for sampling rpc: " f"tcp://{worker_options.master_addr}:{worker_options.master_port}" ) # Validate server ranks servers = input_nodes.keys() if len(servers) > 0: if ( max(servers) >= dataset.cluster_info.num_storage_nodes or min(servers) < 0 ): raise ValueError( f"When using Graph Store mode, the server ranks must be in range " f"[0, {dataset.cluster_info.num_storage_nodes}), " f"received inputs for servers: {list(servers)}" ) # Extract node type and label edge types from the ABLPInputNodes dataclass. # All entries should have the same anchor_node_type and edge type keys. first_input = next(iter(input_nodes.values())) input_type = first_input.anchor_node_type is_homogeneous_with_labeled_edge_type = ( input_type == DEFAULT_HOMOGENEOUS_NODE_TYPE ) # Extract supervision edge types and derive label edge types from the # ABLPInputNodes.labels dict (keyed by supervision edge type). self._supervision_edge_types = list(first_input.labels.keys()) has_negatives = any(neg is not None for _, neg in first_input.labels.values()) self._positive_label_edge_types = [ message_passing_to_positive_label(et) for et in self._supervision_edge_types ] self._negative_label_edge_types = ( [ message_passing_to_negative_label(et) for et in self._supervision_edge_types ] if has_negatives else [] ) # Graph Store mode currently only supports a single supervision edge type, # so the labels dict must have exactly 1 entry. if len(self._supervision_edge_types) != 1: raise ValueError( f"Graph Store mode currently only supports a single supervision edge type, " f"but ABLPInputNodes.labels has {len(self._supervision_edge_types)} " f"entries: {self._supervision_edge_types}" ) logger.info(f"Positive label edge types: {self._positive_label_edge_types}") logger.info(f"Negative label edge types: {self._negative_label_edge_types}") # Convert from ABLPInputNodes to list of ABLPNodeSamplerInput (one per server) input_data: list[ABLPNodeSamplerInput] = [] for server_rank in range(dataset.cluster_info.num_storage_nodes): positive_label_by_edge_type: dict[EdgeType, torch.Tensor] = {} negative_label_by_edge_type: dict[EdgeType, torch.Tensor] = {} if server_rank in input_nodes: ablp_input_nodes = input_nodes[server_rank] anchors = ablp_input_nodes.anchor_nodes for supervision_edge_type, ( positive_labels, negative_labels, ) in ablp_input_nodes.labels.items(): positive_label_by_edge_type[ message_passing_to_positive_label(supervision_edge_type) ] = positive_labels if negative_labels is not None: negative_label_by_edge_type[ message_passing_to_negative_label(supervision_edge_type) ] = negative_labels else: # Empty input for servers with no data for this rank anchors = torch.empty(0, dtype=torch.long) positive_label_by_edge_type = { et: torch.empty(0, 0, dtype=torch.long) for et in self._positive_label_edge_types } if has_negatives: negative_label_by_edge_type = { et: torch.empty(0, 0, dtype=torch.long) for et in self._negative_label_edge_types } logger.info( f"Rank: {torch.distributed.get_rank()}! Building ABLPNodeSamplerInput for server rank: {server_rank} " f"with input type: {input_type}. anchors: {anchors.shape}, " f"positive_labels edge types: {list(positive_label_by_edge_type.keys())}, " f"negative_labels edge types: {list(negative_label_by_edge_type.keys())}" ) ablp_input = ABLPNodeSamplerInput( node=anchors, input_type=input_type, positive_label_by_edge_types=positive_label_by_edge_type, negative_label_by_edge_types=negative_label_by_edge_type, ) input_data.append(ablp_input) return ( input_data, worker_options, DatasetSchema( is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type, edge_types=edge_types, node_feature_info=node_feature_info, edge_feature_info=edge_feature_info, edge_dir=dataset.fetch_edge_dir(), ), ) def _set_labels( self, data: Union[Data, HeteroData], positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor], negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor], ) -> Union[Data, HeteroData]: """ Sets the labels and relevant fields in the torch_geometric Data object, converting the global node ids for labels to their local index. Removes inserted supervision edge type from the data variables, since this is an implementation detail and should not be exposed in the final HeteroData/Data object. Args: data (Union[Data, HeteroData]): Graph to provide labels for positive_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Dict[positive label edge type, label ID tensor], where the ith row of the tensor corresponds to the ith anchor node ID. negative_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Dict[negative label edge type, label ID tensor], where the ith row of the tensor corresponds to the ith anchor node ID. Returns: Union[Data, HeteroData]: torch_geometric HeteroData/Data object with the filtered edge fields and labels set as properties of the instance """ # shape [N], where N is the number of nodes in the subgraph, and local_node_to_global_node[i] gives the global node id for local node id `i` node_type_to_local_node_to_global_node: dict[NodeType, torch.Tensor] = {} if isinstance(data, HeteroData): for e_type in self._supervision_edge_types: node_type_to_local_node_to_global_node[e_type[0]] = data[e_type[0]].node node_type_to_local_node_to_global_node[e_type[2]] = data[e_type[2]].node else: node_type_to_local_node_to_global_node[ DEFAULT_HOMOGENEOUS_NODE_TYPE ] = data.node output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict( dict ) output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict( dict ) # We always have supervision edge types of the form (anchor_node_type, to, supervision_node_type) # So we can index into the edge type accordingly. edge_index = 2 for edge_type, label_tensor in positive_labels_by_label_edge_type.items(): for local_anchor_node_id in range(label_tensor.size(0)): positive_mask = ( node_type_to_local_node_to_global_node[ edge_type[edge_index] ].unsqueeze(1) == label_tensor[local_anchor_node_id] ) # shape [N, P], where N is the number of nodes and P is the number of positive labels for the current anchor node # Gets the indexes of the items in local_node_to_global_node which match any of the positive labels for the current anchor node output_positive_labels[ label_edge_type_to_message_passing_edge_type(edge_type) ][local_anchor_node_id] = torch.nonzero(positive_mask)[:, 0].to( self.to_device ) # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the positive labels for the current anchor node for edge_type, label_tensor in negative_labels_by_label_edge_type.items(): for local_anchor_node_id in range(label_tensor.size(0)): negative_mask = ( node_type_to_local_node_to_global_node[ edge_type[edge_index] ].unsqueeze(1) == label_tensor[local_anchor_node_id] ) # shape [N, M], where N is the number of nodes and M is the number of negative labels for the current anchor node # Gets the indexes of the items in local_node_to_global_node which match any of the negative labels for the current anchor node output_negative_labels[ label_edge_type_to_message_passing_edge_type(edge_type) ][local_anchor_node_id] = torch.nonzero(negative_mask)[:, 0].to( self.to_device ) # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the negative labels for the current anchor node if not output_positive_labels: raise ValueError("No positive labels were found in the data!") elif len(output_positive_labels) == 1: data.y_positive = next(iter(output_positive_labels.values())) else: data.y_positive = output_positive_labels if len(output_negative_labels) == 1: data.y_negative = next(iter(output_negative_labels.values())) elif len(output_negative_labels) > 0: data.y_negative = output_negative_labels return data def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: # extract_metadata separates #META. keys from the message to work # around a GLT bug in to_hetero_data. extract_edge_type_metadata then # pulls out labels by prefix. # TODO (mkolodner-sc): Remove the need to extract metadata once GLT's `to_hetero_data` function is fixed metadata, stripped_msg = extract_metadata(msg, self.to_device) data = super()._collate_fn(stripped_msg) data = set_missing_features( data=data, node_feature_info=self._node_feature_info, edge_feature_info=self._edge_feature_info, device=self.to_device, ) matched, metadata = extract_edge_type_metadata( metadata=metadata, prefixes=[POSITIVE_LABEL_METADATA_KEY, NEGATIVE_LABEL_METADATA_KEY], ) positive_labels = matched[POSITIVE_LABEL_METADATA_KEY] negative_labels = matched[NEGATIVE_LABEL_METADATA_KEY] # When edge_dir="in", GLT internally swaps src/dst on all edge types during sampling, # so the sampler encodes label edge types in their reversed (incoming) form. # We reverse them back here to restore the original outward edge type that # _set_labels and downstream code expect. if self.edge_dir == "in": positive_labels = { reverse_edge_type(et): v for et, v in positive_labels.items() } negative_labels = { reverse_edge_type(et): v for et, v in negative_labels.items() } if isinstance(data, HeteroData): data = strip_label_edges(data) if self._is_homogeneous_with_labeled_edge_type: if len(self._supervision_edge_types) != 1: raise ValueError( f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}" ) data = labeled_to_homogeneous(self._supervision_edge_types[0], data) data = self._set_labels(data, positive_labels, negative_labels) if isinstance(self._sampler_options, PPRSamplerOptions): matched_ppr, metadata = extract_edge_type_metadata( metadata=metadata, prefixes=[PPR_EDGE_INDEX_METADATA_KEY, PPR_WEIGHT_METADATA_KEY], ) ppr_edge_indices = matched_ppr[PPR_EDGE_INDEX_METADATA_KEY] ppr_weights = matched_ppr[PPR_WEIGHT_METADATA_KEY] attach_ppr_outputs(data, ppr_edge_indices, ppr_weights) if isinstance(data, HeteroData): data = strip_non_ppr_edge_types(data, set(ppr_edge_indices.keys())) # Attach any remaining metadata (e.g. custom user-defined keys) directly onto the # data object so downstream code can access them via attribute lookup. for key, value in metadata.items(): data[key] = value return data