from collections import Counter, abc
from typing import Optional, Union
import torch
from graphlearn_torch.channel import SampleMessage, ShmChannel
from graphlearn_torch.distributed import (
DistLoader,
MpDistSamplingWorkerOptions,
get_context,
)
from graphlearn_torch.sampler import SamplingConfig, SamplingType
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType
import gigl.distributed.utils
from gigl.common.logger import Logger
from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_link_prediction_dataset import DistLinkPredictionDataset
from gigl.distributed.dist_sampling_producer import DistSamplingProducer
from gigl.distributed.distributed_neighborloader import DEFAULT_NUM_CPU_THREADS
from gigl.distributed.sampler import ABLPNodeSamplerInput
from gigl.distributed.utils.neighborloader import (
labeled_to_homogeneous,
patch_fanout_for_sampling,
shard_nodes_by_process,
strip_label_edges,
)
from gigl.src.common.types.graph_data import (
NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing
)
from gigl.types.graph import (
DEFAULT_HOMOGENEOUS_EDGE_TYPE,
DEFAULT_HOMOGENEOUS_NODE_TYPE,
reverse_edge_type,
select_label_edge_types,
)
from gigl.utils.data_splitters import get_labels_for_anchor_nodes

logger = Logger()
class DistABLPLoader(DistLoader):
def __init__(
self,
dataset: DistLinkPredictionDataset,
num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
input_nodes: Optional[
Union[
torch.Tensor,
tuple[NodeType, torch.Tensor],
]
] = None,
# TODO(kmonte): Support multiple supervision edge types.
supervision_edge_type: Optional[EdgeType] = None,
num_workers: int = 1,
batch_size: int = 1,
pin_memory_device: Optional[torch.device] = None,
worker_concurrency: int = 4,
channel_size: str = "4GB",
process_start_gap_seconds: float = 60.0,
num_cpu_threads: Optional[int] = None,
shuffle: bool = False,
drop_last: bool = False,
context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate this
local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this
local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this
):
"""
Neighbor loader for Anchor Based Link Prediction (ABLP) tasks.
Note that for this class, the dataset must *always* be heterogeneous,
as we need separate edge types for positive and negative labels.
By default, the loader will return {py:class}`torch_geometric.data.HeteroData` (heterogeneous) objects,
but will return a {py:class}`torch_geometric.data.Data` (homogeneous) object if the dataset is "labeled homogeneous".
The following fields may also be present:
- `y_positive`: `Dict[int, torch.Tensor]` mapping from local anchor node id to a tensor of positive
label node ids.
- `y_negative`: (Optional) `Dict[int, torch.Tensor]` mapping from local anchor node id to a tensor of negative
label node ids. This will only be present if the supervision edge type has negative labels.
NOTE: for both y_positive and y_negative, the keys and values of the dicts are
the *local* node ids of the sampled nodes, not the global node ids.
In order to get the global node ids, you can use the `node` field of the Data/HeteroData object.
e.g. global_positive_node_id_labels = data.node[data.y_positive[local_anchor_node_id]].
The underlying graph engine may also add the following fields to the output Data object:
- num_sampled_nodes: If heterogeneous, a dictionary mapping from node type to the number of sampled nodes for that type, by hop.
If homogeneous, a tensor denoting the number of sampled nodes, by hop.
- num_sampled_edges: If heterogeneous, a dictionary mapping from edge type to the number of sampled edges for that type, by hop.
If homogeneous, a tensor denoting the number of sampled edges, by hop.
Let's use the following homogeneous graph (https://is.gd/a8DK15) as an example:
0 -> 1 [label="Positive example" color="green"]
0 -> 2 [label="Negative example" color="red"]
0 -> {3, 4}
3 -> {5, 6}
4 -> {7, 8}
1 -> 9 # shouldn't be sampled
2 -> 10 # shouldn't be sampled
For sampling around node `0`, the fields on the output Data object will be:
- `y_positive`: {0: torch.tensor([1])} # 1 is the only positive label for node 0
- `y_negative`: {0: torch.tensor([2])} # 2 is the only negative label for node 0
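For illustration, here is a minimal usage sketch (the `dataset`, fanout, seed node ids, and batch size below are assumed placeholders, not defaults):
```python
import torch

loader = DistABLPLoader(
    dataset=dataset,          # a DistLinkPredictionDataset, assumed to already be built
    num_neighbors=[10, 10],   # two-hop fanout
    input_nodes=torch.arange(100),
    batch_size=32,
)
for data in loader:
    for local_anchor, local_positives in data.y_positive.items():
        # `data.node` maps local (sampled) node ids back to global node ids.
        global_anchor_id = data.node[local_anchor]
        global_positive_ids = data.node[local_positives]
```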
Args:
dataset (DistLinkPredictionDataset): The dataset to sample from.
num_neighbors (list[int] or Dict[tuple[str, str, str], list[int]]):
The number of neighbors to sample for each node in each iteration.
If an entry is set to `-1`, all neighbors will be included.
In heterogeneous graphs, this may also be a dictionary denoting
the number of neighbors to sample for each individual edge type.
input_nodes (Optional[Union[torch.Tensor, tuple[NodeType, torch.Tensor]]]):
Indices of seed nodes to start sampling from.
If set to `None` for homogeneous settings, all nodes will be considered.
In heterogeneous graphs, this flag must be passed in as a tuple that holds
the node type and node indices. (default: `None`)
num_workers (int): How many workers to use (subprocesses to spawn) for
distributed neighbor sampling of the current process. (default: ``1``).
batch_size (int, optional): how many samples per batch to load
(default: ``1``).
pin_memory_device (torch.device, optional): The target device that the sampled
results should be copied to. If set to ``None``, the device is inferred via
``gigl.distributed.utils.device.get_available_device``, which uses the local process rank
and ``torch.cuda.device_count()`` to assign a device. If CUDA is not available,
the CPU device will be used. (default: ``None``).
worker_concurrency (int): The max sampling concurrency for each sampling
worker. Load testing has shown that setting worker_concurrency to 4 yields the best sampling
performance, although you may wish to explore higher/lower settings when performance tuning.
(default: ``4``).
channel_size (int or str): The shared-memory buffer size (bytes) allocated
for the channel. Can be modified for performance tuning; a good starting point is: ``num_workers * 64MB``
(default: "4GB").
process_start_gap_seconds (float): Delay between each process for initializing neighbor loader. At large scales,
it is recommended to set this value to be between 60 and 120 seconds -- otherwise multiple processes may
attempt to initialize dataloaders at overlapping times, which can cause CPU memory OOM.
num_cpu_threads (Optional[int]): Number of CPU threads PyTorch should use for CPU-based
training/inference neighbor loading, on top of the per-process parallelism.
Defaults to ``2`` if set to ``None`` when using CPU training/inference.
shuffle (bool): Whether to shuffle the input nodes. (default: ``False``).
drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``).
context (deprecated - will be removed soon) (Optional[DistributedContext]): Distributed context information of the current process.
local_process_rank (deprecated - will be removed soon) (int): The local rank of the current process within a node.
local_process_world_size (deprecated - will be removed soon) (int): The total number of processes within a node.
"""
# Set self._shutdowned right away so that if we throw here and __del__ is called,
# we can clean up properly and don't get extraneous error messages.
# We set to `True` as we don't need to cleanup right away, and this will get set
# to `False` in super().__init__()` e.g.
# https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_loader.py#L125C1-L126C1
self._shutdowned = True
node_world_size: int
node_rank: int
rank: int
world_size: int
local_rank: int
local_world_size: int
master_ip_address: str
should_cleanup_distributed_context: bool = False
if context:
assert (
local_process_world_size is not None
), "context: DistributedContext provided, so local_process_world_size must be provided."
assert (
local_process_rank is not None
), "context: DistributedContext provided, so local_process_rank must be provided."
master_ip_address = context.main_worker_ip_address
node_world_size = context.global_world_size
node_rank = context.global_rank
local_world_size = local_process_world_size
local_rank = local_process_rank
rank = node_rank * local_world_size + local_rank
world_size = node_world_size * local_world_size
if not torch.distributed.is_initialized():
logger.info(
"process group is not available, trying to torch.distributed.init_process_group to communicate necessary setup information."
)
should_cleanup_distributed_context = True
logger.info(
f"Initializing process group with master ip address: {master_ip_address}, rank: {rank}, world size: {world_size}, local_rank: {local_rank}, local_world_size: {local_world_size}"
)
torch.distributed.init_process_group(
backend="gloo", # We just default to gloo for this temporary process group
init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}",
rank=rank,
world_size=world_size,
)
else:
assert (
torch.distributed.is_initialized()
), f"context: DistributedContext is None, so process group must be initialized before constructing this object {self.__class__.__name__}."
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
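# Derive the node-level topology from the existing process group: ranks are grouped by their
# internal IP address, ranks sharing the master's IP form one node, and every node is expected
# to host the same number of processes.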
rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks()
master_ip_address = rank_ip_addresses[0]
count_ranks_per_ip_address = Counter(rank_ip_addresses)
local_world_size = count_ranks_per_ip_address[master_ip_address]
for rank_ip_address, count in count_ranks_per_ip_address.items():
if count != local_world_size:
raise ValueError(
f"All ranks must have the same number of processes, but found {count} processes for rank {rank} on ip {rank_ip_address}, expected {local_world_size}."
+ f"count_ranks_per_ip_address = {count_ranks_per_ip_address}"
)
node_world_size = len(count_ranks_per_ip_address)
local_rank = rank % local_world_size
node_rank = rank // local_world_size
del (
context,
local_process_rank,
local_process_world_size,
) # delete deprecated vars so we don't accidentally use them.
if not isinstance(dataset.graph, abc.Mapping):
raise ValueError(
f"The dataset must be heterogeneous for ABLP. Recieved dataset with graph of type: {type(dataset.graph)}"
)
self._is_input_heterogeneous: bool = False
if isinstance(input_nodes, tuple):
if supervision_edge_type is None:
raise ValueError(
"When using heterogeneous ABLP, you must provide supervision_edge_types."
)
self._is_input_heterogeneous = True
anchor_node_type, anchor_node_ids = input_nodes
# TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if
# this assumption is no longer valid and/or is too opinionated
assert (
supervision_edge_type[0] == anchor_node_type
), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \
got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}"
supervision_node_type = supervision_edge_type[2]
if dataset.edge_dir == "in":
supervision_edge_type = reverse_edge_type(supervision_edge_type)
elif isinstance(input_nodes, torch.Tensor):
if supervision_edge_type is not None:
raise ValueError(
f"Expected supervision edge type to be None for homogeneous input nodes, got {supervision_edge_type}"
)
anchor_node_ids = input_nodes
anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE
supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
elif input_nodes is None:
if dataset.node_ids is None:
raise ValueError(
"Dataset must have node ids if input_nodes are not provided."
)
if isinstance(dataset.node_ids, abc.Mapping):
raise ValueError(
f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}"
)
if supervision_edge_type is not None:
raise ValueError(
f"Expected supervision edge type to be None for homogeneous input nodes, got {supervision_edge_type}"
)
anchor_node_ids = dataset.node_ids
anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE
supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
missing_edge_types = set([supervision_edge_type]) - set(dataset.graph.keys())
if missing_edge_types:
raise ValueError(
f"Missing edge types in dataset: {missing_edge_types}. Edge types in dataset: {dataset.graph.keys()}"
)
if len(anchor_node_ids.shape) != 1:
raise ValueError(
f"input_nodes must be a 1D tensor, got {anchor_node_ids.shape}."
)
(
self._positive_label_edge_type,
self._negative_label_edge_type,
) = select_label_edge_types(supervision_edge_type, dataset.graph.keys())
self._supervision_edge_type = supervision_edge_type
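# positive_labels / negative_labels are tensors whose ith row holds the global label node ids
# for anchor_node_ids[i]; negative_labels is None if the dataset has no negative label edge type.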
positive_labels, negative_labels = get_labels_for_anchor_nodes(
dataset=dataset,
node_ids=anchor_node_ids,
positive_label_edge_type=self._positive_label_edge_type,
negative_label_edge_type=self._negative_label_edge_type,
)
self.to_device = (
pin_memory_device
if pin_memory_device
else gigl.distributed.utils.get_available_device(
local_process_rank=local_rank
)
)
# TODO(kmonte): stop setting fanout for positive/negative once GLT sampling is fixed.
num_neighbors = patch_fanout_for_sampling(
dataset.get_edge_types(), num_neighbors
)
if num_neighbors.keys() != dataset.graph.keys():
raise ValueError(
f"num_neighbors must have all edge types in the graph, received: {num_neighbors.keys()} with for graph with edge types {dataset.graph.keys()}"
)
hops = len(next(iter(num_neighbors.values())))
if not all(len(fanout) == hops for fanout in num_neighbors.values()):
raise ValueError(
f"num_neighbors must be a dict of edge types with the same number of hops. Received: {num_neighbors}"
)
curr_process_nodes = shard_nodes_by_process(
input_nodes=anchor_node_ids,
local_process_rank=local_rank,
local_process_world_size=local_world_size,
)
# Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize
# the memory overhead and CPU contention.
neighbor_loader_ports = gigl.distributed.utils.get_free_ports_from_master_node(
num_ports=local_world_size
)
neighbor_loader_port_for_current_rank = neighbor_loader_ports[local_rank]
logger.info(
f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} using device: {self.to_device} on port {neighbor_loader_port_for_current_rank}."
)
should_use_cpu_workers = self.to_device.type == "cpu"
if should_use_cpu_workers and num_cpu_threads is None:
logger.info(
"Using CPU workers, but found num_cpu_threads to be None. "
f"Will default setting num_cpu_threads to {DEFAULT_NUM_CPU_THREADS}."
)
num_cpu_threads = DEFAULT_NUM_CPU_THREADS
gigl.distributed.utils.init_neighbor_loader_worker(
master_ip_address=master_ip_address,
local_process_rank=local_rank,
local_process_world_size=local_world_size,
rank=node_rank,
world_size=node_world_size,
master_worker_port=neighbor_loader_port_for_current_rank,
device=self.to_device,
should_use_cpu_workers=should_use_cpu_workers,
# Lever to explore tuning for CPU based inference
num_cpu_threads=num_cpu_threads,
process_start_gap_seconds=process_start_gap_seconds,
)
logger.info(
f"Finished initializing neighbor loader worker: {local_rank}/{local_world_size}"
)
# Sets up worker options for the dataloader
dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node(
num_ports=local_world_size
)
dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank]
worker_options = MpDistSamplingWorkerOptions(
num_workers=num_workers,
worker_devices=[torch.device("cpu") for _ in range(num_workers)],
worker_concurrency=worker_concurrency,
# Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group
# need to be connected. Thus, we need master ip address and master port to
# initiate the connection.
# Note that different groups of workers are independent, and thus
# the sampling processes in different groups should be independent, and should
# use different master ports.
master_addr=master_ip_address,
master_port=dist_sampling_port_for_current_rank,
# Load testing shows that when num_rpc_threads exceeds 16, the performance
# will degrade.
num_rpc_threads=min(dataset.num_partitions, 16),
rpc_timeout=600,
channel_size=channel_size,
pin_memory=self.to_device.type == "cuda",
)
if should_cleanup_distributed_context and torch.distributed.is_initialized():
logger.info(
f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
)
torch.distributed.destroy_process_group()
sampler_input = ABLPNodeSamplerInput(
node=curr_process_nodes,
input_type=anchor_node_type,
positive_labels=positive_labels,
negative_labels=negative_labels,
supervision_node_type=supervision_node_type,
)
sampling_config = SamplingConfig(
sampling_type=SamplingType.NODE,
num_neighbors=num_neighbors,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
with_edge=True,
collect_features=True,
with_neg=False,
with_weight=False,
edge_dir=dataset.edge_dir,
seed=None, # it's actually optional - None means random.
)
# Code below this point is taken from the GLT DistNeighborLoader.__init__() function (graphlearn_torch/python/distributed/dist_neighbor_loader.py).
# We do this so that we can use GiGL's implementation of DistSamplingProducer in place of GLT's default.
# The copied GLT loader code below expects `self.data` and `self.input_data` to be populated.
self.data = dataset
self.input_data = sampler_input
self.sampling_type = sampling_config.sampling_type
self.num_neighbors = sampling_config.num_neighbors
self.batch_size = sampling_config.batch_size
self.shuffle = sampling_config.shuffle
self.drop_last = sampling_config.drop_last
self.with_edge = sampling_config.with_edge
self.with_weight = sampling_config.with_weight
self.collect_features = sampling_config.collect_features
self.edge_dir = sampling_config.edge_dir
self.sampling_config = sampling_config
self.worker_options = worker_options
# We can set shutdowned to false now
self._shutdowned = False
self._is_mp_worker = True
self._is_collocated_worker = False
self._is_remote_worker = False
self.num_data_partitions = self.data.num_partitions
self.data_partition_idx = self.data.partition_idx
self._set_ntypes_and_etypes(
self.data.get_node_types(), self.data.get_edge_types()
)
self._num_recv = 0
self._epoch = 0
current_ctx = get_context()
self._input_len = len(self.input_data)
self._input_type = self.input_data.input_type
self._num_expected = self._input_len // self.batch_size
if not self.drop_last and self._input_len % self.batch_size != 0:
self._num_expected += 1
if not current_ctx.is_worker():
raise RuntimeError(
f"'{self.__class__.__name__}': only supports "
f"launching multiprocessing sampling workers with "
f"a non-server distribution mode, current role of "
f"distributed context is {current_ctx.role}."
)
if self.data is None:
raise ValueError(
f"'{self.__class__.__name__}': missing input dataset "
f"when launching multiprocessing sampling workers."
)
# Launch multiprocessing sampling workers
self._with_channel = True
self.worker_options._set_worker_ranks(current_ctx)
self._channel = ShmChannel(
self.worker_options.channel_capacity, self.worker_options.channel_size
)
if self.worker_options.pin_memory:
self._channel.pin_memory()
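# GiGL's DistSamplingProducer (rather than GLT's default producer) is constructed here;
# its init() spawns the sampling worker processes that feed the shared-memory channel consumed by _collate_fn.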
self._mp_producer = DistSamplingProducer(
self.data,
self.input_data,
self.sampling_config,
self.worker_options,
self._channel,
)
self._mp_producer.init()
def _get_labels(
self, msg: SampleMessage
) -> tuple[SampleMessage, torch.Tensor, Optional[torch.Tensor]]:
# TODO (mkolodner-sc): Remove the need to modify metadata once GLT's `to_hetero_data` function is fixed
"""
Gets the labels from the output SampleMessage and removes them from the metadata. We need to remove the labels from GLT's metadata since the
`to_hetero_data` function strangely assumes that we are doing edge-based sampling if the metadata is not empty at the time of
building the HeteroData object.
Args:
msg (SampleMessage): All possible results from a sampler, including subgraph data, features, and user-defined metadata
Returns:
SampleMessage: Updated sample message with the label fields removed
torch.Tensor: Positive label ID tensor, where the ith row corresponds to the ith anchor node ID
Optional[torch.Tensor]: Negative label ID tensor, where the ith row corresponds to the ith anchor node ID, can be None if dataset has no negative labels
"""
metadata = {}
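# GLT stores user-defined metadata in the sample message under keys prefixed with "#META.";
# strip that prefix and move each metadata tensor to the target device.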
for k in list(msg.keys()):
if k.startswith("#META."):
meta_key = str(k[6:])
metadata[meta_key] = msg[k].to(self.to_device)
del msg[k]
positive_labels = metadata["positive_labels"]
negative_labels = (
metadata["negative_labels"] if "negative_labels" in metadata else None
)
return (msg, positive_labels, negative_labels)
def _set_labels(
self,
data: Union[Data, HeteroData],
positive_labels: torch.Tensor,
negative_labels: Optional[torch.Tensor],
) -> Union[Data, HeteroData]:
"""
Sets the labels and relevant fields in the torch_geometric Data object, converting the global node ids for labels to their
local index. Removes inserted supervision edge type from the data variables, since this is an implementation detail and should not be
exposed in the final HeteroData/Data object.
Args:
data (Union[Data, HeteroData]): Graph to provide labels for
positive_labels (torch.Tensor): Positive label ID tensor, where the ith row corresponds to the ith anchor node ID
negative_labels (Optional[torch.Tensor]): Negative label ID tensor, where the ith row corresponds to the ith anchor node ID,
can be None if dataset has no negative labels
Returns:
Union[Data, HeteroData]: torch_geometric HeteroData/Data object with the filtered edge fields and labels set as properties of the instance
"""
local_node_to_global_node: torch.Tensor
# shape [N], where N is the number of nodes in the subgraph, and local_node_to_global_node[i] gives the global node id for local node id `i`
if isinstance(data, HeteroData):
supervision_node_type = (
self._supervision_edge_type[0]
if self.edge_dir == "in"
else self._supervision_edge_type[2]
)
local_node_to_global_node = data[supervision_node_type].node
else:
local_node_to_global_node = data.node
output_positive_labels: dict[int, torch.Tensor] = {}
output_negative_labels: dict[int, torch.Tensor] = {}
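# Illustrative example (assumed values): if local_node_to_global_node = tensor([0, 3, 4, 1, 2]) and
# positive_labels[0] = tensor([1]), the mask below matches local index 3, so
# output_positive_labels[0] = tensor([3]) -- i.e. global node 1 sits at local index 3.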
for local_anchor_node_id in range(positive_labels.size(0)):
positive_mask = (
local_node_to_global_node.unsqueeze(1)
== positive_labels[local_anchor_node_id]
) # shape [N, P], where N is the number of nodes and P is the number of positive labels for the current anchor node
# Gets the indexes of the items in local_node_to_global_node which match any of the positive labels for the current anchor node
output_positive_labels[local_anchor_node_id] = torch.nonzero(positive_mask)[
:, 0
].to(self.to_device)
# Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the positive labels for the current anchor node
if negative_labels is not None:
negative_mask = (
local_node_to_global_node.unsqueeze(1)
== negative_labels[local_anchor_node_id]
) # shape [N, M], where N is the number of nodes and M is the number of negative labels for the current anchor node
# Gets the indexes of the items in local_node_to_global_node which match any of the negative labels for the current anchor node
output_negative_labels[local_anchor_node_id] = torch.nonzero(
negative_mask
)[:, 0].to(self.to_device)
# Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the negative labels for the current anchor node
data.y_positive = output_positive_labels
if negative_labels is not None:
data.y_negative = output_negative_labels
return data
def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
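# Labels are pulled out of the raw SampleMessage before delegating to the base _collate_fn,
# since GLT's `to_hetero_data` treats non-empty metadata as edge-based sampling (see _get_labels).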
msg, positive_labels, negative_labels = self._get_labels(msg)
data = super()._collate_fn(msg)
if isinstance(data, HeteroData):
data = strip_label_edges(data)
if not self._is_input_heterogeneous:
data = labeled_to_homogeneous(self._supervision_edge_type, data)
data = self._set_labels(data, positive_labels, negative_labels)
return data