Source code for gigl.transforms.utils

from typing import Dict, Optional, Tuple, Union

import torch
from torch import Tensor
from torch_geometric.data import HeteroData

# Type alias for edge types in PyG HeteroData
[docs] EdgeType = Tuple[str, str, str]
[docs] def add_node_attr( data: HeteroData, values: Union[Tensor, Dict[str, Tensor]], attr_name: Optional[str] = None, node_type_to_idx: Optional[Dict[str, Tuple[int, int]]] = None, ) -> HeteroData: """Helper function to add node attributes to a HeteroData object. Args: data: The HeteroData object to modify. values: Either: - A tensor of values in homogeneous node order (requires node_type_to_idx or will be computed from data.node_types), OR - A dictionary mapping node types to tensors of values for each type. attr_name: The name of the attribute to add. If None, concatenates to existing `x` attribute for each node type (or creates it). node_type_to_idx: Optional mapping from node type to (start, end) indices. Only used when values is a tensor. If None, it will be computed from data.node_types. Returns: The modified HeteroData object. """ # If values is a dictionary, directly assign to each node type if isinstance(values, dict): for node_type, value in values.items(): if node_type not in data.node_types: continue _set_node_attr_for_type(data, node_type, value, attr_name) return data # Otherwise, values is a tensor in homogeneous order - split by node type if node_type_to_idx is None: # Build mapping from node type to (start, end) indices in homogeneous tensor # When HeteroData is converted to homogeneous, nodes are ordered by node type. # This mapping lets us slice the homogeneous tensor to get values for each type. # Example: if data has 3 'user' nodes and 2 'item' nodes: # node_type_to_idx = {'user': (0, 3), 'item': (3, 5)} node_type_to_idx = {} start_idx = 0 for node_type in data.node_types: num_type_nodes = data[node_type].num_nodes node_type_to_idx[node_type] = (start_idx, start_idx + num_type_nodes) start_idx += num_type_nodes for node_type in data.node_types: start, end = node_type_to_idx[node_type] value = values[start:end] _set_node_attr_for_type(data, node_type, value, attr_name) return data
def _set_node_attr_for_type( data: HeteroData, node_type: str, value: Tensor, attr_name: Optional[str], ) -> None: """Helper to set node attribute for a single node type.""" if attr_name is None: # Concatenate to existing x or create new x # Use getattr to safely get x attribute, returns None if not present x = getattr(data[node_type], "x", None) if x is not None: # Existing features found: concatenate new values to them # Reshape 1D tensor [num_nodes] to 2D [num_nodes, 1] for concatenation x = x.view(-1, 1) if x.dim() == 1 else x # Move value to same device/dtype as x, then concatenate along feature dim data[node_type].x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1) else: # No existing features: use new values as x directly data[node_type].x = value else: data[node_type][attr_name] = value
[docs] def add_edge_attr( data: HeteroData, values: Union[Tensor, Dict[EdgeType, Tensor]], attr_name: Optional[str] = None, edge_type_to_idx: Optional[Dict[EdgeType, Tuple[int, int]]] = None, ) -> HeteroData: """Helper function to add edge attributes to a HeteroData object. Args: data: The HeteroData object to modify. values: Either: - A tensor of values in homogeneous edge order (requires edge_type_to_idx or will be computed from data.edge_types), OR - A dictionary mapping edge types to tensors of values for each type. attr_name: The name of the attribute to add. If None, concatenates to existing `edge_attr` attribute for each edge type (or creates it). edge_type_to_idx: Optional mapping from edge type to (start, end) indices. Only used when values is a tensor. If None, it will be computed from data.edge_types. Returns: The modified HeteroData object. """ # If values is a dictionary, directly assign to each edge type if isinstance(values, dict): for edge_type, value in values.items(): if edge_type not in data.edge_types: continue _set_edge_attr_for_type(data, edge_type, value, attr_name) return data # Otherwise, values is a tensor in homogeneous order - split by edge type if edge_type_to_idx is None: # Build mapping from edge type to (start, end) indices in homogeneous tensor # When HeteroData is converted to homogeneous, edges are ordered by edge type. # This mapping lets us slice the homogeneous tensor to get values for each type. # Example: if data has 3 'buys' edges and 2 'views' edges: # edge_type_to_idx = {('user', 'buys', 'item'): (0, 3), ('user', 'views', 'item'): (3, 5)} edge_type_to_idx = {} start_idx = 0 for edge_type in data.edge_types: num_type_edges = data[edge_type].num_edges edge_type_to_idx[edge_type] = (start_idx, start_idx + num_type_edges) start_idx += num_type_edges for edge_type in data.edge_types: start, end = edge_type_to_idx[edge_type] value = values[start:end] _set_edge_attr_for_type(data, edge_type, value, attr_name) return data
def _set_edge_attr_for_type( data: HeteroData, edge_type: EdgeType, value: Tensor, attr_name: Optional[str], ) -> None: """Helper to set edge attribute for a single edge type.""" if attr_name is None: # Concatenate to existing edge_attr or create new edge_attr # Use getattr to safely get edge_attr attribute, returns None if not present edge_attr = getattr(data[edge_type], "edge_attr", None) if edge_attr is not None: # Existing features found: concatenate new values to them # Reshape 1D tensor [num_edges] to 2D [num_edges, 1] for concatenation edge_attr = edge_attr.view(-1, 1) if edge_attr.dim() == 1 else edge_attr # Move value to same device/dtype as edge_attr, then concatenate along feature dim data[edge_type].edge_attr = torch.cat( [edge_attr, value.to(edge_attr.device, edge_attr.dtype)], dim=-1 ) else: # No existing features: use new values as edge_attr directly data[edge_type].edge_attr = value else: data[edge_type][attr_name] = value