# Source code for gigl.transforms.add_positional_encodings

from typing import Optional

import torch
from torch_geometric.data import HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_torch_sparse_tensor

from gigl.transforms.utils import add_node_attr

r"""
Positional and Structural Encodings for Heterogeneous Graphs.

This module provides PyG-compatible transforms for adding positional and structural
encodings to HeteroData objects. All transforms follow the PyG BaseTransform interface
and can be composed using `torch_geometric.transforms.Compose`.

Available Transforms:
    - AddHeteroRandomWalkEncodings: Combined random walk PE and SE in single pass
    - AddHeteroHopDistanceEncoding: Shortest path distance encoding

Example Usage:
    >>> from torch_geometric.data import HeteroData
    >>> from torch_geometric.transforms import Compose
    >>> from gigl.transforms.add_positional_encodings import (
    ...     AddHeteroRandomWalkEncodings,
    ...     AddHeteroHopDistanceEncoding,
    ... )
    >>>
    >>> # Create a heterogeneous graph
    >>> data = HeteroData()
    >>> data['user'].x = torch.randn(5, 16)
    >>> data['item'].x = torch.randn(3, 16)
    >>> data['user', 'buys', 'item'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
    >>> data['item', 'bought_by', 'user'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
    >>>
    >>> # Apply random walk encoding transform (computes both PE and SE)
    >>> transform = AddHeteroRandomWalkEncodings(walk_length=8)
    >>> data = transform(data)
    >>> print(data['user'].random_walk_pe.shape)  # (5, 8)
    >>> print(data['user'].random_walk_se.shape)  # (5, 8)
    >>>
    >>> # Compose with hop distance encoding
    >>> transform = Compose([
    ...     AddHeteroRandomWalkEncodings(walk_length=8),
    ...     AddHeteroHopDistanceEncoding(h_max=3),
    ... ])
    >>> data = transform(data)
    >>>
    >>> # For Graph Transformers, use hop distance encoding for attention bias
    >>> # Returns sparse matrix (0 for unreachable, 1-h_max for reachable pairs)
    >>> transform = AddHeteroHopDistanceEncoding(h_max=5)
    >>> data = transform(data)
    >>> print(data.hop_distance.shape)       # (num_total_nodes, num_total_nodes) sparse
    >>> print(data.hop_distance.is_sparse)   # True
"""


@functional_transform("add_hetero_random_walk_encodings")
class AddHeteroRandomWalkEncodings(BaseTransform):
    r"""Adds both random walk positional and structural encodings to the given
    heterogeneous graph (functional name: :obj:`add_hetero_random_walk_encodings`).

    This transform computes both encodings in a single pass over the random walk
    matrix, which is more efficient than applying separate transforms.

    **Positional Encoding (PE):**
    For each node j, computes the sum of transition probabilities from all other
    nodes to j after k steps of a random walk, for k = 1, 2, ..., walk_length.
    This captures how "reachable" or "central" a node is from the rest of the
    graph. The encoding is the column sum of non-diagonal elements of the k-step
    random walk matrix:

        PE[j, k] = Σ_{i≠j} (P^k)[i, j]

    where P is the transition matrix. This measures the probability mass flowing
    into node j from all other nodes at step k.

    **Structural Encoding (SE):**
    For each node, computes the probability of returning to itself after k steps
    of a random walk, for k = 1, 2, ..., walk_length. This captures the local
    structural role of each node (e.g., cycles, clustering coefficient). The
    encoding is the diagonal of the k-step random walk matrix:

        SE[i, k] = (P^k)[i, i]

    Based on the approach from `"Graph Neural Networks with Learnable Structural
    and Positional Representations" <https://arxiv.org/abs/2110.07875>`_.

    Args:
        walk_length (int): The number of random walk steps.
        pe_attr_name (str, optional): The attribute name of the positional
            encoding. (default: :obj:`"random_walk_pe"`)
        se_attr_name (str, optional): The attribute name of the structural
            encoding. (default: :obj:`"random_walk_se"`)
        is_undirected (bool, optional): If set to :obj:`True`, the graph is
            assumed to be undirected, and the adjacency matrix will be made
            symmetric. (default: :obj:`False`)
        attach_to_x (bool, optional): If set to :obj:`True`, the encodings are
            concatenated directly to :obj:`data[node_type].x` for each node type
            instead of being stored as separate attributes.
            PE is concatenated first, then SE. (default: :obj:`False`)
    """

    def __init__(
        self,
        walk_length: int,
        pe_attr_name: Optional[str] = "random_walk_pe",
        se_attr_name: Optional[str] = "random_walk_se",
        is_undirected: bool = False,
        attach_to_x: bool = False,
    ) -> None:
        # Configuration is stored verbatim; all computation happens in forward().
        self.walk_length = walk_length
        self.pe_attr_name = pe_attr_name
        self.se_attr_name = se_attr_name
        self.is_undirected = is_undirected
        self.attach_to_x = attach_to_x

    def forward(self, data: HeteroData) -> HeteroData:
        """Compute PE and SE for every node and attach them to ``data``.

        Args:
            data (HeteroData): The heterogeneous graph to annotate.

        Returns:
            HeteroData: The same ``data`` object, mutated in place with the
            encodings added (as attributes, or concatenated to ``x`` when
            ``attach_to_x`` is set).

        Raises:
            AssertionError: If ``data`` is not a :class:`HeteroData` instance.
        """
        assert isinstance(data, HeteroData), (
            f"'{self.__class__.__name__}' only supports 'HeteroData' "
            f"(got '{type(data)}')"
        )

        # Convert to homogeneous so a single transition matrix covers all
        # node/edge types; node ordering follows to_homogeneous().
        homo_data = data.to_homogeneous()
        edge_index = homo_data.edge_index
        num_nodes = homo_data.num_nodes

        if num_nodes == 0:
            # Degenerate graph: emit all-zero encodings of the right width so
            # downstream feature dimensions stay consistent.
            for node_type in data.node_types:
                empty_encoding = torch.zeros(
                    (data[node_type].num_nodes, self.walk_length),
                    dtype=torch.float,
                )
                # Handle PE. Passing attr_name=None signals add_node_attr to
                # concatenate onto x instead of storing a named attribute.
                effective_pe_attr_name = None if self.attach_to_x else self.pe_attr_name
                add_node_attr(
                    data, {node_type: empty_encoding.clone()}, effective_pe_attr_name
                )
                # Handle SE
                effective_se_attr_name = None if self.attach_to_x else self.se_attr_name
                add_node_attr(
                    data, {node_type: empty_encoding.clone()}, effective_se_attr_name
                )
            return data

        # Compute transition matrix (row-stochastic) using sparse operations
        adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes))

        if self.is_undirected:
            # Make symmetric for undirected graphs
            adj = (adj + adj.t()).coalesce()

        # Compute degree for row normalization
        adj_coalesced = adj.coalesce()
        deg = torch.zeros(num_nodes, device=edge_index.device)
        deg.scatter_add_(0, adj_coalesced.indices()[0], adj_coalesced.values().float())
        # Clamp so isolated nodes (degree 0) do not divide by zero; their
        # transition rows remain all-zero.
        deg = torch.clamp(deg, min=1)  # Avoid division by zero

        # Create row-normalized transition matrix (sparse)
        # P[i,j] = A[i,j] / deg[i]
        row_indices = adj_coalesced.indices()[0]
        normalized_values = adj_coalesced.values().float() / deg[row_indices]
        transition = torch.sparse_coo_tensor(
            adj_coalesced.indices(),
            normalized_values,
            size=(num_nodes, num_nodes),
        ).coalesce()

        # Compute both PE and SE in a single pass over the random walk
        # PE[j, k] = sum of column j excluding diagonal = Σ_{i≠j} (P^k)[i, j]
        # SE[i, k] = diagonal element = (P^k)[i, i]
        pe = torch.zeros(
            (num_nodes, self.walk_length), dtype=torch.float, device=edge_index.device
        )
        se = torch.zeros(
            (num_nodes, self.walk_length), dtype=torch.float, device=edge_index.device
        )

        # Start with identity matrix (sparse), so after the first multiply
        # `current` holds P^1, after the k-th multiply P^k.
        identity_indices = torch.arange(num_nodes, device=edge_index.device)
        current = torch.sparse_coo_tensor(
            torch.stack([identity_indices, identity_indices]),
            torch.ones(num_nodes, device=edge_index.device),
            size=(num_nodes, num_nodes),
        ).coalesce()

        for k in range(self.walk_length):
            current = torch.sparse.mm(current, transition).coalesce()

            # Column sum = sum over rows for each column (for PE)
            col_sum = torch.zeros(num_nodes, device=edge_index.device)
            col_sum.scatter_add_(0, current.indices()[1], current.values())

            # Extract diagonal elements (for both PE and SE)
            diag = torch.zeros(num_nodes, device=edge_index.device)
            diag_mask = current.indices()[0] == current.indices()[1]
            if diag_mask.any():
                diag.scatter_add_(
                    0, current.indices()[0][diag_mask], current.values()[diag_mask]
                )

            # PE: column sum excluding diagonal
            pe[:, k] = col_sum - diag
            # SE: diagonal elements (return probability)
            se[:, k] = diag

        # Map back to HeteroData node types
        # If attach_to_x is True, pass None as attr_name to concatenate to x directly
        # NOTE(review): here add_node_attr receives a homogeneous (num_nodes,
        # walk_length) tensor, whereas the empty-graph branch above passes a
        # per-node-type dict — confirm the util supports both call shapes.
        effective_pe_attr_name = None if self.attach_to_x else self.pe_attr_name
        add_node_attr(data, pe, effective_pe_attr_name)
        effective_se_attr_name = None if self.attach_to_x else self.se_attr_name
        add_node_attr(data, se, effective_se_attr_name)

        return data

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(walk_length={self.walk_length}, "
            f"attach_to_x={self.attach_to_x})"
        )
@functional_transform("add_hetero_hop_distance_encoding")
class AddHeteroHopDistanceEncoding(BaseTransform):
    r"""Adds hop distance positional encoding as relative encoding (sparse CSR).

    For each pair of nodes (vi, vj), computes the shortest path distance
    p(vi, vj). This captures structural proximity and can be used with a
    learnable embedding matrix:

        h_hop(vi, vj) = W_hop · onehot(p(vi, vj))

    Based on the approach from `"Do Transformers Really Perform Bad for Graph
    Representation?" <https://arxiv.org/abs/2106.05234>`_ (Graphormer).

    The output is a **sparse CSR matrix** where:

    - Reachable pairs (i, j) within h_max hops have value = hop distance (1 to h_max)
    - Unreachable pairs have value = 0 (not stored in sparse tensor)
    - Self-loops (diagonal) are not stored (distance to self is implicitly 0)

    CSR format is used for efficient row-based lookups during sequence building.

    Args:
        h_max (int): Maximum hop distance to consider. Distances > h_max are
            treated as unreachable (value 0 in sparse matrix).
            Set to 2-3 for 2-hop sampled subgraphs.
            Set to min(walk_length // 2, 10) for random walk sampled subgraphs.
        attr_name (str, optional): The attribute name of the positional
            encoding. (default: :obj:`"hop_distance"`)
        is_undirected (bool, optional): If set to :obj:`True`, the graph is
            assumed to be undirected for distance computation.
            (default: :obj:`False`)
    """

    def __init__(
        self,
        h_max: int,
        attr_name: Optional[str] = "hop_distance",
        is_undirected: bool = False,
    ) -> None:
        # Configuration only; the BFS runs in forward().
        self.h_max = h_max
        self.attr_name = attr_name
        self.is_undirected = is_undirected

    def forward(self, data: HeteroData) -> HeteroData:
        """Compute pairwise hop distances and store them on ``data``.

        Args:
            data (HeteroData): The heterogeneous graph to annotate.

        Returns:
            HeteroData: The same ``data`` object with a sparse CSR
            (num_nodes, num_nodes) distance matrix stored under
            ``self.attr_name`` as a graph-level attribute.

        Raises:
            AssertionError: If ``data`` is not a :class:`HeteroData` instance.
        """
        assert isinstance(data, HeteroData), (
            f"'{self.__class__.__name__}' only supports 'HeteroData' "
            f"(got '{type(data)}')"
        )

        # Convert to homogeneous to compute shortest paths
        homo_data = data.to_homogeneous()
        edge_index = homo_data.edge_index
        num_nodes = homo_data.num_nodes
        num_edges = edge_index.size(1)

        if num_nodes == 0 or num_edges == 0:
            # Handle empty graph case - return empty sparse CSR tensor
            # (crow_indices of length num_nodes + 1, all zeros => every row empty)
            empty_sparse = torch.sparse_csr_tensor(
                torch.zeros(num_nodes + 1, dtype=torch.long),
                torch.zeros(0, dtype=torch.long),
                torch.zeros(0, dtype=torch.float),
                size=(num_nodes, num_nodes),
            )
            data[self.attr_name] = empty_sparse
            return data

        device = edge_index.device

        # Build sparse adjacency matrix for shortest path computation
        adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes))

        if self.is_undirected:
            # Make symmetric for undirected graphs
            adj = (adj + adj.t()).coalesce()

        # Binarize adjacency (sparse) — multi-edges must not inflate the
        # reachability values before the matmul-based BFS below.
        adj_coalesced = adj.coalesce()
        adj = torch.sparse_coo_tensor(
            adj_coalesced.indices(),
            torch.ones(adj_coalesced.indices().size(1), device=device),
            size=(num_nodes, num_nodes),
        ).coalesce()

        # Memory-optimized BFS for computing shortest path distances
        #
        # Key memory optimizations:
        # 1. Use sorted linear indices with searchsorted for O(n log n) membership test
        #    (more memory efficient than torch.isin which may create hash tables)
        # 2. Store distances as int8 (h_max typically < 127)
        # 3. Avoid tensor concatenation in hot loop - use pre-sorted merge instead
        # 4. Explicit del statements to trigger garbage collection
        # 5. CSR format for sparse matmul (more memory efficient than COO)
        #
        # Memory complexity: O(nnz_frontier + nnz_visited) per iteration
        # where nnz_frontier can grow up to O(n^2) for dense graphs at large hop
        dist_matrix_rows = []
        dist_matrix_cols = []
        dist_matrix_vals = []

        # Choose tracking strategy based on graph size
        # For small graphs (n < 10000), bitmap is faster with O(1) lookup
        # For large graphs, sorted indices use less memory O(visited) vs O(n^2/8)
        USE_BITMAP = num_nodes < 10000

        if USE_BITMAP:
            # Dense bitmap: O(n^2 / 8) bytes, O(1) lookup
            # For n=10000, this is ~12.5 MB
            visited_bitmap = torch.zeros(
                num_nodes, num_nodes, dtype=torch.bool, device=device
            )
            visited_bitmap.fill_diagonal_(True)  # Mark diagonal as visited
        else:
            # Sorted linear indices: O(visited pairs) memory, O(log n) lookup
            identity_indices = torch.arange(num_nodes, device=device, dtype=torch.long)
            visited_linear = identity_indices * num_nodes + identity_indices  # Diagonal
            visited_linear = visited_linear.sort()[0]

        # Adjacency matrix in CSR format (more memory efficient for matmul)
        adj_csr = adj.to_sparse_csr()
        del adj  # Free COO adjacency

        # Current frontier (reachable pairs at current hop distance)
        frontier = adj_csr.to_sparse_coo().coalesce()

        for hop in range(1, self.h_max + 1):
            if hop > 1:
                # frontier = frontier @ adj (sparse matmul)
                # CSR @ CSR is most efficient
                frontier_csr = frontier.to_sparse_csr()
                del frontier
                frontier = (
                    torch.sparse.mm(frontier_csr, adj_csr).to_sparse_coo().coalesce()
                )
                del frontier_csr

            frontier_indices = frontier.indices()
            num_frontier = frontier_indices.size(1)

            if num_frontier == 0:
                # Nothing left to expand — all reachable pairs found.
                break

            reach_i, reach_j = frontier_indices[0], frontier_indices[1]

            if USE_BITMAP:
                # O(1) lookup using dense bitmap
                is_visited = visited_bitmap[reach_i, reach_j]
                is_new = ~is_visited
                del is_visited
            else:
                # O(log n) lookup using sorted searchsorted
                frontier_linear = reach_i.long() * num_nodes + reach_j.long()
                insert_pos = torch.searchsorted(visited_linear, frontier_linear)
                # Clamp so positions past the end index the last element; the
                # subsequent equality test then correctly reports "not visited".
                insert_pos_clamped = insert_pos.clamp(max=visited_linear.size(0) - 1)
                is_visited = visited_linear[insert_pos_clamped] == frontier_linear
                is_new = ~is_visited
                del frontier_linear, insert_pos, insert_pos_clamped, is_visited

            num_new = int(is_new.sum().item())

            if num_new > 0:
                new_i = reach_i[is_new]
                new_j = reach_j[is_new]
                dist_matrix_rows.append(new_i)
                dist_matrix_cols.append(new_j)
                # Use int8 for hop distance (saves 4x memory vs float32)
                dist_matrix_vals.append(
                    torch.full((num_new,), hop, device=device, dtype=torch.int8)
                )

                # Update visited
                if USE_BITMAP:
                    visited_bitmap[new_i, new_j] = True
                else:
                    new_linear = new_i.long() * num_nodes + new_j.long()
                    visited_linear = torch.cat([visited_linear, new_linear]).sort()[0]
                    del new_linear

            del is_new, reach_i, reach_j

        # Clean up
        if USE_BITMAP:
            del visited_bitmap
        else:
            del visited_linear
        del adj_csr, frontier

        # Build sparse distance matrix
        if dist_matrix_rows:
            dist_rows = torch.cat(dist_matrix_rows)
            dist_cols = torch.cat(dist_matrix_cols)
            # Convert int8 to float for downstream compatibility
            dist_vals = torch.cat(dist_matrix_vals).float()
            # Free intermediate lists
            del dist_matrix_rows, dist_matrix_cols, dist_matrix_vals
        else:
            dist_rows = torch.zeros(0, dtype=torch.long, device=device)
            dist_cols = torch.zeros(0, dtype=torch.long, device=device)
            dist_vals = torch.zeros(0, dtype=torch.float, device=device)

        # Create sparse distance matrix in CSR format directly
        # CSR is more efficient for row-based lookups in _lookup_csr_values
        # Unreachable pairs have value 0 (not stored)
        # Reachable pairs have value = hop distance (1 to h_max)
        dist_coo = torch.sparse_coo_tensor(
            torch.stack([dist_rows, dist_cols]),
            dist_vals,
            size=(num_nodes, num_nodes),
        ).coalesce()
        dist_sparse = dist_coo.to_sparse_csr()
        del dist_coo

        # Store sparse pairwise distance matrix as graph-level attribute
        # Access via: data.hop_distance or data['hop_distance']
        # Usage in attention: use sparse indexing for memory efficiency
        # Note: Node ordering follows data.to_homogeneous() order (by node_type
        # alphabetically)
        data[self.attr_name] = dist_sparse

        return data

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(h_max={self.h_max})"