"""
Utility functions for computing node degrees in distributed graph settings.
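This module provides functions to compute node degrees from graph partitions
and aggregate them across distributed machines. Degrees are computed from the
CSR (Compressed Sparse Row) topology stored in GraphLearn-Torch Graph objects;
for heterogeneous graphs the anchor node of each edge is its destination when
``edge_dir="in"`` and its source when ``edge_dir="out"``.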
For homogeneous graphs, callers receive a single ``torch.Tensor``. For
heterogeneous graphs, degrees are accumulated per anchor node type (summing
across all edge types incident to that node type) before the distributed
all-reduce, so callers receive ``dict[NodeType, torch.Tensor]``.
Requirements
============
torch.distributed must be initialized before calling these functions.
Usage
=====
Access dataset.degree_tensor to lazily compute and cache the degree tensor.
Over-counting correction is handled automatically in _all_reduce_degrees by
detecting how many processes share the same machine (and thus the same data).
Heterogeneous partitioned graphs are expected to materialize all registered
non-label edge types on every rank, even when a rank has no local edges for a
type. This keeps the per-node-type all-reduce order consistent across ranks.
Degree tensors are stored as int32 to stay aligned with the PPR C++ sampler's
total-degree tensor requirement while keeping memory lower than int64. We avoid
int16 because it has caused compatibility issues in this path: during the C++
PPR sampler migration, ``torch.distributed.all_reduce`` on an int16 tensor
produced ``RuntimeError: Invalid scalar type``. Values above the int32 maximum
are clamped before casting to avoid wraparound.
"""
from collections import Counter
from typing import Final, Union, overload
import torch
from graphlearn_torch.data import Graph
from graphlearn_torch.typing import NodeType
from torch_geometric.typing import EdgeType
from gigl.common.logger import Logger
from gigl.distributed.utils.device import get_device_from_process_group
from gigl.distributed.utils.networking import get_internal_ip_from_all_ranks
from gigl.types.graph import is_label_edge_type
# Upper bound used to saturate degree values before casting to int32.
_INT32_MAX: Final[int] = torch.iinfo(torch.int32).max

# Module-level logger used by the degree-computation functions in this module.
logger = Logger()
def compute_and_broadcast_degree_tensor(
    graph: Union[Graph, dict[EdgeType, Graph]],
    edge_dir: str,
) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
    """Compute node degrees from a graph and aggregate across all machines.

    For each non-label edge type, degrees are derived from the CSR row pointers
    (indptr). For heterogeneous graphs, degrees are summed across all edge types
    incident to each anchor node type **locally** before the all-reduce, so the
    per-edge-type tensor is only a transient intermediate and is never stored,
    returned, or transmitted over RPC.

    Over-counting correction (for processes sharing the same data) is handled
    automatically by detecting the distributed topology.

    Args:
        graph: A Graph (homogeneous) or dict[EdgeType, Graph] (heterogeneous).
            For heterogeneous graphs, label edge types are automatically excluded
            — they are supervision edges and should not contribute to node degree
            for graph traversal algorithms like PPR.
        edge_dir: Sampling direction — ``"in"`` or ``"out"``. Determines which
            end of each edge is the anchor node type for degree accumulation
            (only consulted for heterogeneous graphs).

    Returns:
        Union[torch.Tensor, dict[NodeType, torch.Tensor]]: Aggregated degree
        tensors. For homogeneous graphs, returns an int32 tensor of shape
        ``[num_nodes]``. For heterogeneous graphs, returns int32 tensors
        keyed by node type with shape ``[num_nodes_of_that_type]``.

    Raises:
        RuntimeError: If torch.distributed is not initialized.
        ValueError: If topology is unavailable for a homogeneous graph, or if
            ``edge_dir`` is not ``"in"``/``"out"`` for a heterogeneous graph.
    """
    if not torch.distributed.is_initialized():
        raise RuntimeError(
            "compute_and_broadcast_degree_tensor requires torch.distributed to be initialized."
        )

    if isinstance(graph, Graph):
        topo = graph.topo
        if topo is None or topo.indptr is None:
            raise ValueError("Topology/indptr not available for graph.")
        # Homogeneous graphs keep the usual GiGL shape: a single tensor.
        result = _all_reduce_degrees(_compute_degrees_from_indptr(topo.indptr))
        _log_degree_summary(result)
        return result

    # Validation and local aggregation happen before any collective, so a bad
    # argument fails on every rank instead of desynchronizing the all-reduce.
    local_dict = _accumulate_local_degrees_by_node_type(graph, edge_dir)
    # All-reduce across ranks after local per-node-type aggregation.
    result_by_type = _all_reduce_degrees(local_dict)
    for node_type, degrees in result_by_type.items():
        _log_degree_summary(degrees, node_type)
    return result_by_type


def _accumulate_local_degrees_by_node_type(
    graph: dict[EdgeType, Graph],
    edge_dir: str,
) -> dict[NodeType, torch.Tensor]:
    """Sum this rank's per-edge-type degrees into one int32 tensor per anchor node type.

    Raises:
        ValueError: If ``edge_dir`` is not ``"in"`` or ``"out"``.
    """
    if edge_dir not in ("in", "out"):
        raise ValueError(f"edge_dir must be 'in' or 'out', got {edge_dir!r}.")

    local_dict: dict[NodeType, torch.Tensor] = {}
    for edge_type, edge_graph in graph.items():
        # Label edge types are supervision edges and should not contribute to
        # node degree for traversal algorithms like PPR.
        if is_label_edge_type(edge_type):
            continue
        anchor_type: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0]
        topo = edge_graph.topo
        if topo is None or topo.indptr is None:
            logger.warning(
                f"Topology/indptr not available for edge type {edge_type}, using empty tensor."
            )
            degrees = torch.empty(0, dtype=torch.int32)
        else:
            degrees = _compute_degrees_from_indptr(topo.indptr)

        existing = local_dict.get(anchor_type)
        if existing is None:
            local_dict[anchor_type] = degrees
            continue
        # Sum in int64 (the .to() also copies, so `existing` is never mutated),
        # then saturate back to int32.
        max_len = max(len(existing), len(degrees))
        summed = _pad_to_size(existing, max_len).to(torch.int64)
        summed[: len(degrees)] += degrees.to(torch.int64)
        local_dict[anchor_type] = _clamp_to_int32(summed)
    return local_dict


def _log_degree_summary(
    degrees: torch.Tensor,
    node_type: Union[NodeType, None] = None,
) -> None:
    """Log node count and min/max degree of one aggregated degree tensor."""
    if degrees.numel() == 0:
        if node_type is None:
            logger.info("Graph contained 0 nodes when computing degrees")
        else:
            logger.info(
                f"Graph contained 0 nodes for node type {node_type} when computing degrees"
            )
        return
    prefix = "" if node_type is None else f"{node_type}: "
    logger.info(
        f"{prefix}{degrees.size(0)} nodes, "
        f"max={degrees.max().item()}, min={degrees.min().item()}"
    )
def _pad_to_size(tensor: torch.Tensor, target_size: int) -> torch.Tensor:
    """Right-pad a 1-D tensor with zeros until it has ``target_size`` entries.

    The input object itself is returned when it is already at least
    ``target_size`` long; it is never truncated. Padding reuses the input's
    dtype and device.
    """
    shortfall = target_size - tensor.size(0)
    if shortfall <= 0:
        return tensor
    return torch.cat([tensor, tensor.new_zeros((shortfall,))])
def _clamp_to_int32(tensor: torch.Tensor) -> torch.Tensor:
    """Saturate values at the int32 maximum, then cast the result to int32.

    Only the upper bound is clamped, so large degrees saturate instead of
    wrapping around; values are not clamped from below.
    """
    saturated = torch.clamp(tensor, max=_INT32_MAX)
    return saturated.to(dtype=torch.int32)
def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor:
    """Turn CSR row pointers into per-row degree counts, saturated to int32.

    Row ``i`` spans ``indptr[i]`` up to ``indptr[i + 1]``, so its degree is the
    difference of consecutive pointers. The result has ``len(indptr) - 1``
    entries (empty when ``indptr`` has fewer than two entries).
    """
    row_ends = indptr[1:]
    row_starts = indptr[:-1]
    return _clamp_to_int32(row_ends - row_starts)
def _get_degree_reduce_context() -> tuple[int, torch.device]:
    """Count the ranks sharing this rank's machine and pick the reduce device.

    Returns:
        tuple[int, torch.device]: The number of ranks whose internal IP equals
        this rank's IP (this rank included), and the device reported by
        ``get_device_from_process_group`` for the active process group.

    Raises:
        RuntimeError: If torch.distributed is not initialized.
    """
    if not torch.distributed.is_initialized():
        raise RuntimeError(
            "_all_reduce_degrees requires torch.distributed to be initialized."
        )
    # One IP per rank, indexed by rank; ranks on the same host are assumed to
    # share the same partition data.
    rank_ips = get_internal_ip_from_all_ranks()
    own_ip = rank_ips[torch.distributed.get_rank()]
    ranks_per_ip = Counter(rank_ips)
    # NCCL backend requires CUDA tensors; Gloo works with CPU.
    return ranks_per_ip[own_ip], get_device_from_process_group()
def _all_reduce_single_degree_tensor(
    tensor: torch.Tensor,
    local_world_size: int,
    device: torch.device,
) -> torch.Tensor:
    """Sum one degree tensor over all ranks and divide out duplicate contributions.

    The ranks first agree on the longest length any of them holds (MAX
    all-reduce). Each rank then zero-pads its tensor to that length and SUM
    all-reduces it as int64 on ``device``. Finally the sum is floor-divided by
    ``local_world_size`` and saturated to int32 on CPU.

    Args:
        tensor: This rank's 1-D degree tensor.
        local_world_size: Divisor applied to the summed result.
        device: Device on which the collectives run.

    Returns:
        torch.Tensor: int32 CPU tensor of the agreed (maximum) length.
    """
    size_buf = torch.tensor([tensor.size(0)], dtype=torch.long, device=device)
    torch.distributed.all_reduce(size_buf, op=torch.distributed.ReduceOp.MAX)
    global_len = int(size_buf.item())

    # int64 avoids overflow while summing; the device matters because NCCL
    # only reduces CUDA tensors.
    work = _pad_to_size(tensor, global_len).to(torch.int64).to(device)
    torch.distributed.all_reduce(work, op=torch.distributed.ReduceOp.SUM)

    corrected = work // local_world_size
    # Saturate (rather than wrap) very high-degree nodes when narrowing.
    return _clamp_to_int32(corrected).cpu()
def _get_machine_leader_and_device() -> tuple[bool, torch.device]:
    """Return whether this rank leads its machine, plus the all-reduce device.

    Ranks are grouped by internal IP address; the lowest-numbered rank on each
    IP is that machine's leader.

    Raises:
        RuntimeError: If torch.distributed is not initialized.
    """
    if not torch.distributed.is_initialized():
        raise RuntimeError(
            "_all_reduce_degrees requires torch.distributed to be initialized."
        )
    all_ips = get_internal_ip_from_all_ranks()
    my_rank = torch.distributed.get_rank()
    my_ip = all_ips[my_rank]
    # ``all_ips`` is indexed by rank, so the first match is the lowest rank on
    # this machine. ``my_ip`` is always present, so ``next`` cannot fail.
    leader_rank = next(rank for rank, ip in enumerate(all_ips) if ip == my_ip)
    # NCCL backend requires CUDA tensors; Gloo works with CPU.
    device = get_device_from_process_group()
    return leader_rank == my_rank, device


@overload
def _all_reduce_degrees(
    local_degrees: dict[NodeType, torch.Tensor],
) -> dict[NodeType, torch.Tensor]: ...


@overload
def _all_reduce_degrees(local_degrees: torch.Tensor) -> torch.Tensor: ...


def _all_reduce_degrees(
    local_degrees: Union[torch.Tensor, dict[NodeType, torch.Tensor]],
) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
    """All-reduce degree tensors across ranks.

    Moves tensors to GPU for the all-reduce if using NCCL backend (which
    requires CUDA), otherwise keeps tensors on CPU (for Gloo backend).

    Over-counting correction:
        In distributed training, multiple processes on the same machine often
        share the same graph partition data (via shared memory). If every
        process contributed its "local" degrees, a partition read by 4
        processes would be summed 4 times instead of once.

        To count each machine's data exactly once, only the lowest rank on each
        machine (grouped by internal IP) contributes its degrees. Every other
        rank contributes zeros of the same length. This stays correct even when
        machines run different numbers of processes.

        Example: Machine A has 2 processes sharing degrees [3, 5, 2].
        Machine B has 1 process with degrees [1, 4, 6].
        Leader contributions: [3, 5, 2] + [0, 0, 0] + [1, 4, 6] = [4, 9, 8]
        (correct on every rank). Dividing a plain sum [7, 14, 10] by each
        rank's local process count would instead give [3, 7, 5] on A and
        [7, 14, 10] on B.

    Args:
        local_degrees: Either a homogeneous degree tensor or a dict mapping
            NodeType to local degree tensors.

    Returns:
        Aggregated degree tensors matching the input shape.

    Raises:
        RuntimeError: If torch.distributed is not initialized.
    """
    is_leader, device = _get_machine_leader_and_device()

    def _reduce_one(tensor: torch.Tensor) -> torch.Tensor:
        # Non-leaders keep their tensor length, so the size sync is unaffected,
        # but contribute zeros. Each machine is then counted once, so no
        # division is needed (correction factor 1).
        contribution = tensor if is_leader else torch.zeros_like(tensor)
        return _all_reduce_single_degree_tensor(contribution, 1, device)

    if isinstance(local_degrees, torch.Tensor):
        return _reduce_one(local_degrees)

    # Heterogeneous case: all-reduce each node type in deterministic (sorted)
    # order so every rank issues the collectives in the same sequence.
    return {
        node_type: _reduce_one(local_degrees[node_type])
        for node_type in sorted(local_degrees.keys())
    }