Source code for gigl.src.common.models.graph_transformer.graph_transformer

"""Graph Transformer encoder for heterogeneous graphs.

Adapted from RelGT's LocalModule (https://github.com/snap-stanford/relgt).
Converts heterogeneous graph data into fixed-length sequences via
``heterodata_to_graph_transformer_input``, processes through a stack of pre-norm
transformer encoder layers, then produces per-node embeddings via
attention-weighted neighbor readout.

Conforms to the same forward interface as ``HGT`` and ``SimpleHGN`` in
``gigl.src.common.models.pyg.heterogeneous``, making it a drop-in
replacement as the encoder in ``LinkPredictionGNN``.
"""

import math
from typing import Callable, Literal, Optional, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data.hetero_data
from torch import Tensor

from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.transforms.graph_transformer import (
    PPR_WEIGHT_FEATURE_NAME,
    SequenceAuxiliaryData,
    TokenInputData,
    heterodata_to_graph_transformer_input,
)


def _get_node_type_positional_encodings(
    data: torch_geometric.data.hetero_data.HeteroData,
    node_type: NodeType,
    pe_attr_names: list[str],
    device: torch.device,
) -> Tensor:
    """Collect concatenated node-level PE for a single node type."""
    pe_parts = []
    sorted_node_types = sorted(data.node_types)
    node_store = data[node_type]

    for attr_name in pe_attr_names:
        if hasattr(node_store, attr_name):
            pe_parts.append(getattr(node_store, attr_name).to(device))
            continue

        attr_dim = None
        for other_node_type in sorted_node_types:
            other_store = data[other_node_type]
            if hasattr(other_store, attr_name):
                attr_dim = getattr(other_store, attr_name).size(-1)
                break
        if attr_dim is None:
            raise ValueError(
                f"Positional encoding '{attr_name}' not found in any node type."
            )
        pe_parts.append(torch.zeros(node_store.num_nodes, attr_dim, device=device))

    return torch.cat(pe_parts, dim=-1)


def _build_sinusoidal_sequence_position_table(
    max_seq_len: int,
    hid_dim: int,
) -> Tensor:
    """Build a standard sinusoidal absolute position table."""
    positions = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, hid_dim, 2, dtype=torch.float) * (-math.log(10000.0) / hid_dim)
    )

    position_table = torch.zeros(max_seq_len, hid_dim, dtype=torch.float)
    position_table[:, 0::2] = torch.sin(positions * div_term)
    if hid_dim > 1:
        position_table[:, 1::2] = torch.cos(
            positions * div_term[: position_table[:, 1::2].shape[1]]
        )
    return position_table


# Supported activation functions for FeedForwardNetwork
_ACTIVATION_FNS = {
    "gelu": nn.GELU,
    "relu": nn.ReLU,
    "silu": nn.SiLU,  # Also known as Swish
    "tanh": nn.Tanh,
}

# XGLU activations use a gating mechanism: activation(xW) * xV
# where W and V are separate linear projections
_XGLU_BASE_ACTIVATIONS = {
    "geglu": F.gelu,
    "swiglu": F.silu,
    "reglu": F.relu,
}


[docs] class FeedForwardNetwork(nn.Module): """Two-layer feed-forward network with configurable activation. Supports standard activations (GELU, ReLU, SiLU) and XGLU family (SwiGLU, GeGLU, ReGLU) which use a gating mechanism. Note: This module does NOT include LayerNorm. Normalization should be applied externally (e.g., pre-norm in the transformer layer). Adapted from RelGT's FeedForwardNetwork. Args: model_dim: Model (input and output) dimension of the FFN. feedforward_dim: Inner dimension of the two-layer MLP. dropout_rate: Dropout probability applied after each linear layer. activation: Activation function name. Supported values: - Standard: "gelu" (default), "relu", "silu", "tanh" - XGLU family: "geglu", "swiglu", "reglu" XGLU activations use gating: activation(xW) * xV, which requires projecting to 2x feedforward_dim internally. """ def __init__( self, model_dim: int, feedforward_dim: int, dropout_rate: float = 0.1, activation: str = "gelu", ) -> None: super().__init__() self._activation_name = activation.lower() # Validate activation if ( self._activation_name not in _ACTIVATION_FNS and self._activation_name not in _XGLU_BASE_ACTIVATIONS ): supported = sorted( set(_ACTIVATION_FNS.keys()) | set(_XGLU_BASE_ACTIVATIONS.keys()) ) raise ValueError( f"Unsupported activation '{activation}'. " f"Supported: {supported}" ) self._is_xglu = self._activation_name in _XGLU_BASE_ACTIVATIONS # Type declarations for optional attributes self._xglu_base_activation: Optional[Callable[..., Tensor]] = None self._linear_in: Optional[nn.Linear] = None self._dropout_in: Optional[nn.Dropout] = None self._linear_out: Optional[nn.Linear] = None self._dropout_out: Optional[nn.Dropout] = None self._ffn: Optional[nn.Sequential] = None if self._is_xglu: # XGLU: project to 2x feedforward_dim, split, apply gating self._xglu_base_activation = cast( Callable[..., Tensor], _XGLU_BASE_ACTIVATIONS[self._activation_name] ) self._linear_in = nn.Linear(model_dim, feedforward_dim * 2) self._dropout_in = nn.Dropout(dropout_rate) self._linear_out = nn.Linear(feedforward_dim, model_dim) self._dropout_out = nn.Dropout(dropout_rate) else: # Standard activation activation_fn = _ACTIVATION_FNS[self._activation_name] self._ffn = nn.Sequential( nn.Linear(model_dim, feedforward_dim), activation_fn(), nn.Dropout(dropout_rate), nn.Linear(feedforward_dim, model_dim), nn.Dropout(dropout_rate), )
[docs] def reset_parameters(self) -> None: """Reinitialize all learnable parameters.""" if self._is_xglu: assert self._linear_in is not None assert self._linear_out is not None nn.init.xavier_uniform_(self._linear_in.weight) nn.init.zeros_(self._linear_in.bias) nn.init.xavier_uniform_(self._linear_out.weight) nn.init.zeros_(self._linear_out.bias) else: # Use xavier + zero bias for consistency with XGLU path and # GraphTransformerEncoderLayer (standard Transformer practice) assert self._ffn is not None for layer in self._ffn: if isinstance(layer, nn.Linear): nn.init.xavier_uniform_(layer.weight) if layer.bias is not None: nn.init.zeros_(layer.bias)
[docs] def forward(self, x: Tensor) -> Tensor: """Forward pass. Args: x: Input tensor of shape ``(batch, seq, model_dim)``. Returns: Output tensor of shape ``(batch, seq, model_dim)``. """ if self._is_xglu: # XGLU gating: activation(x @ W1) * (x @ W2) # where W1 and W2 are the two halves of linear_in assert self._xglu_base_activation is not None assert self._linear_in is not None assert self._dropout_in is not None assert self._linear_out is not None assert self._dropout_out is not None x_proj = self._linear_in(x) # (batch, seq, feedforward_dim * 2) x_gate, x_value = x_proj.chunk( 2, dim=-1 ) # Each: (batch, seq, feedforward_dim) x = self._xglu_base_activation(x_gate) * x_value x = self._dropout_in(x) x = self._linear_out(x) x = self._dropout_out(x) else: assert self._ffn is not None x = self._ffn(x) return x
[docs] class GraphTransformerEncoderLayer(nn.Module): """Pre-norm transformer encoder layer with multi-head self-attention. Uses ``F.scaled_dot_product_attention`` which automatically selects the most efficient attention implementation (flash, memory-efficient, or math-based) based on input properties and hardware. Adapted from RelGT's EncoderLayer. Args: model_dim: Model dimension (d_model). num_heads: Number of attention heads. Must evenly divide model_dim. feedforward_dim: Inner dimension of the feed-forward network. dropout_rate: Dropout probability for feed-forward layers. attention_dropout_rate: Dropout probability for attention weights. activation: Activation function for the feed-forward network. Supported values: "gelu" (default), "relu", "silu", "tanh", "geglu", "swiglu", "reglu". Raises: ValueError: If model_dim is not divisible by num_heads. """ def __init__( self, model_dim: int, num_heads: int, feedforward_dim: int, dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, activation: str = "gelu", ) -> None: super().__init__() if model_dim % num_heads != 0: raise ValueError( f"model_dim ({model_dim}) must be divisible by " f"num_heads ({num_heads})" ) self._num_heads = num_heads self._head_dim = model_dim // num_heads self._attention_dropout_rate = attention_dropout_rate self._attention_norm = nn.LayerNorm(model_dim) self._query_projection = nn.Linear(model_dim, model_dim) self._key_projection = nn.Linear(model_dim, model_dim) self._value_projection = nn.Linear(model_dim, model_dim) self._output_projection = nn.Linear(model_dim, model_dim) self._dropout = nn.Dropout(dropout_rate) self._ffn_norm = nn.LayerNorm(model_dim) self._ffn = FeedForwardNetwork( model_dim, feedforward_dim, dropout_rate, activation=activation )
[docs] def reset_parameters(self) -> None: """Reinitialize all learnable parameters.""" self._attention_norm.reset_parameters() for projection in [ self._query_projection, self._key_projection, self._value_projection, self._output_projection, ]: nn.init.xavier_uniform_(projection.weight) if projection.bias is not None: nn.init.zeros_(projection.bias) self._ffn_norm.reset_parameters() self._ffn.reset_parameters()
[docs] def forward( self, x: Tensor, attn_bias: Optional[Tensor] = None, valid_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass. Args: x: Input tensor of shape ``(batch, seq, model_dim)``. attn_bias: Optional attention bias of shape ``(batch, num_heads, seq, seq)`` or broadcastable. Added as an additive mask to attention scores. valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used to zero out padded token states after each residual block. Returns: Output tensor of shape ``(batch, seq, model_dim)``. """ batch_size, seq_len, model_dim = x.shape # Self-attention block (pre-norm) residual = x x_norm = self._attention_norm(x) query = self._query_projection(x_norm) key = self._key_projection(x_norm) value = self._value_projection(x_norm) # Reshape to (batch, num_heads, seq, head_dim) query = query.view( batch_size, seq_len, self._num_heads, self._head_dim ).transpose(1, 2) key = key.view(batch_size, seq_len, self._num_heads, self._head_dim).transpose( 1, 2 ) value = value.view( batch_size, seq_len, self._num_heads, self._head_dim ).transpose(1, 2) attention_output = F.scaled_dot_product_attention( query, key, value, attn_mask=attn_bias, dropout_p=self._attention_dropout_rate if self.training else 0.0, is_causal=False, ) # Reshape back to (batch, seq, model_dim) attention_output = attention_output.transpose(1, 2).reshape( batch_size, seq_len, model_dim ) attention_output = self._output_projection(attention_output) attention_output = self._dropout(attention_output) x = residual + attention_output if valid_mask is not None: x = x * valid_mask.unsqueeze(-1).to(x.dtype) # Feed-forward block (pre-norm) residual = x x_norm = self._ffn_norm(x) ffn_output = self._ffn(x_norm) x = residual + ffn_output if valid_mask is not None: x = x * valid_mask.unsqueeze(-1).to(x.dtype) return x
[docs] class GraphTransformerEncoder(nn.Module): """Graph Transformer encoder for heterogeneous graphs. Converts heterogeneous graph data into fixed-length sequences via ``heterodata_to_graph_transformer_input``, processes through pre-norm transformer encoder layers, and produces per-node embeddings via attention-weighted neighbor readout (from RelGT's LocalModule). Conforms to the same forward interface as ``HGT`` and ``SimpleHGN``, making it a drop-in encoder for ``LinkPredictionGNN``. Args: node_type_to_feat_dim_map: Dictionary mapping node types to their input feature dimensions. edge_type_to_feat_dim_map: Dictionary mapping edge types to their feature dimensions. Accepted for interface conformance with ``HGT``/``SimpleHGN``; edge features are not used by the graph transformer. hid_dim: Hidden dimension for transformer layers. All node types are projected to this dimension before processing. out_dim: Output embedding dimension. num_layers: Number of transformer encoder layers. num_heads: Number of attention heads per layer. Must evenly divide ``hid_dim``. max_seq_len: Maximum sequence length for the graph-to-sequence transform. Neighborhoods are truncated to this length. hop_distance: Number of hops for neighborhood extraction in the graph-to-sequence transform when using ``"khop"`` sequence construction. sequence_construction_method: Sequence builder used to create tokens for each anchor. ``"khop"`` expands the sampled graph by hop distance, while ``"ppr"`` consumes outgoing ``"ppr"`` edges sorted by weight. sequence_positional_encoding_type: Optional sequence-level positional encoding applied after sequence construction. Supported values are ``None`` and ``"sinusoidal"``. Lower-cost future extensions could add learned absolute position embeddings here, while attention-level options like RoPE or ALiBi would require changes inside the attention block. dropout_rate: Dropout probability for feed-forward layers. attention_dropout_rate: Dropout probability for attention weights. should_l2_normalize_embedding_layer_output: Whether to L2 normalize output embeddings. pe_attr_names: List of node-level positional encoding attribute names. In ``"concat"`` mode these are concatenated to sequence features. In ``"add"`` mode they are projected to ``hid_dim`` and added to node features before sequence construction. anchor_based_attention_bias_attr_names: List of anchor-relative feature names used as additive attention bias for sequence keys. Sparse graph-level attributes are looked up from ``data`` and the reserved name ``"ppr_weight"`` resolves to PPR edge weights in PPR mode. Example: ``['hop_distance', 'ppr_weight']`` where ``hop_distance`` is a sparse matrix attribute on ``data`` and ``ppr_weight`` is extracted from PPR edge weights. anchor_based_input_attr_names: List of anchor-relative attribute names used as token-aligned input features. Sparse graph-level attributes are looked up from ``data`` and ``"ppr_weight"`` resolves to PPR edge weights in PPR mode. These are projected to ``hid_dim`` and added to the sequence tokens after sequence construction. Example: ``['hop_distance', 'ppr_weight']`` for continuous features, or ``['hop_distance']`` when ``hop_distance`` will be embedded via ``anchor_based_input_embedding_dict``. anchor_based_input_embedding_dict: Optional ModuleDict mapping a subset of ``anchor_based_input_attr_names`` to per-attribute embedding layers. These attributes are treated as discrete indices and their embedded contributions are added to the sequence tokens. Padding is masked out using the sequence valid mask. Example: ``nn.ModuleDict({'hop_distance': nn.Embedding(10, hid_dim)})`` to embed hop distances 0-9 into ``hid_dim``-dimensional vectors. The embedding output dimension must match ``hid_dim``. pairwise_attention_bias_attr_names: List of pairwise feature names used as additive attention bias. These must correspond to sparse graph-level attributes on ``data``. feature_embedding_layer_dict: Optional ModuleDict mapping node types to feature embedding layers. If provided, these are applied to node features before node projection. (default: None) pe_integration_mode: How to fuse positional encodings into the model input. ``"concat"`` preserves the current behavior by concatenating node-level PE to token features. ``"add"`` uses node-level additive PE before sequence construction and attention bias for relative encodings. activation: Activation function for the feed-forward network in each transformer layer. Supported values: - Standard: "gelu" (default), "relu", "silu", "tanh" - XGLU family: "geglu", "swiglu", "reglu" XGLU activations use gating: activation(xW) * xV. feedforward_ratio: Ratio of feedforward dimension to hidden dimension (feedforward_dim = hid_dim * feedforward_ratio). If None (default), uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, following the convention that XGLU's gating doubles the effective parameters, so a smaller ratio maintains similar parameter count. Notes: This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap it with ``DistributedDataParallel``, run one representative no-grad forward first, passing ``anchor_node_ids``/``anchor_node_type`` for the graph-transformer path, or load a checkpoint before DDP so all ranks see initialized weights. TODO: Pairwise relative bias is currently materialized densely for the selected sequence. That is fine for moderate ``max_seq_len``, but a chunked or sparse LPFormer-style path is still future work for larger sequences. Example: >>> from gigl.src.common.models.graph_transformer.graph_transformer import ( ... GraphTransformerEncoder, ... ) >>> encoder = GraphTransformerEncoder( ... node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32}, ... edge_type_to_feat_dim_map={}, ... hid_dim=128, ... out_dim=64, ... num_layers=2, ... num_heads=4, ... ) >>> embeddings = encoder(data, anchor_node_type=NodeType("user"), device=device) """ def __init__( self, node_type_to_feat_dim_map: dict[NodeType, int], edge_type_to_feat_dim_map: dict[EdgeType, int], hid_dim: int, out_dim: int = 128, num_layers: int = 2, num_heads: int = 2, max_seq_len: int = 128, hop_distance: int = 2, sequence_construction_method: Literal["khop", "ppr"] = "khop", sequence_positional_encoding_type: Optional[str] = None, dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, should_l2_normalize_embedding_layer_output: bool = False, pe_attr_names: Optional[list[str]] = None, anchor_based_attention_bias_attr_names: Optional[list[str]] = None, anchor_based_input_attr_names: Optional[list[str]] = None, anchor_based_input_embedding_dict: Optional[nn.ModuleDict] = None, pairwise_attention_bias_attr_names: Optional[list[str]] = None, feature_embedding_layer_dict: Optional[nn.ModuleDict] = None, pe_integration_mode: Literal["concat", "add"] = "concat", activation: str = "gelu", feedforward_ratio: Optional[float] = None, **kwargs: object, ) -> None: super().__init__() del kwargs if pe_integration_mode not in {"concat", "add"}: raise ValueError( "pe_integration_mode must be one of {'concat', 'add'}, " f"got '{pe_integration_mode}'" ) self._hid_dim = hid_dim self._out_dim = out_dim self._max_seq_len = max_seq_len self._hop_distance = hop_distance if sequence_construction_method not in {"khop", "ppr"}: raise ValueError( "sequence_construction_method must be one of {'khop', 'ppr'}, " f"got '{sequence_construction_method}'" ) if sequence_positional_encoding_type is not None: sequence_positional_encoding_type = ( sequence_positional_encoding_type.lower() ) if sequence_positional_encoding_type == "none": sequence_positional_encoding_type = None if sequence_positional_encoding_type not in {None, "sinusoidal"}: raise ValueError( "sequence_positional_encoding_type must be one of " "{None, 'sinusoidal'}, " f"got '{sequence_positional_encoding_type}'" ) if ( sequence_construction_method == "khop" and sequence_positional_encoding_type is not None ): raise ValueError( "sequence_positional_encoding_type requires " "sequence_construction_method='ppr' because khop sequences do not " "enforce a stable token order." ) anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] anchor_input_attr_names = anchor_based_input_attr_names or [] pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] if PPR_WEIGHT_FEATURE_NAME in pairwise_bias_attr_names: raise ValueError( f"'{PPR_WEIGHT_FEATURE_NAME}' is an anchor-relative feature and " "cannot be used as pairwise attention bias." ) if ( PPR_WEIGHT_FEATURE_NAME in anchor_bias_attr_names + anchor_input_attr_names and sequence_construction_method != "ppr" ): raise ValueError( "The reserved anchor-relative feature 'ppr_weight' requires " "sequence_construction_method='ppr'." ) self._sequence_construction_method = sequence_construction_method self._sequence_positional_encoding_type = sequence_positional_encoding_type self._should_l2_normalize_embedding_layer_output = ( should_l2_normalize_embedding_layer_output ) self._pe_attr_names = pe_attr_names self._anchor_based_attention_bias_attr_names = ( anchor_based_attention_bias_attr_names ) self._anchor_based_input_attr_names = anchor_based_input_attr_names self._anchor_based_input_embedding_dict = anchor_based_input_embedding_dict self._pairwise_attention_bias_attr_names = pairwise_attention_bias_attr_names self._feature_embedding_layer_dict = feature_embedding_layer_dict self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads anchor_input_embedding_attr_names = ( set(anchor_based_input_embedding_dict.keys()) if anchor_based_input_embedding_dict is not None else set() ) invalid_anchor_input_embedding_attr_names = ( anchor_input_embedding_attr_names - set(anchor_input_attr_names) ) if invalid_anchor_input_embedding_attr_names: raise ValueError( "anchor_based_input_embedding_dict keys must be a subset of " "anchor_based_input_attr_names, got unexpected keys " f"{sorted(invalid_anchor_input_embedding_attr_names)}." ) self._continuous_anchor_input_attr_names = [ attr_name for attr_name in anchor_input_attr_names if attr_name not in anchor_input_embedding_attr_names ] if self._sequence_positional_encoding_type == "sinusoidal": self.register_buffer( "_sequence_positional_encoding_table", _build_sinusoidal_sequence_position_table( max_seq_len=max_seq_len, hid_dim=hid_dim, ), persistent=False, ) else: self.register_buffer( "_sequence_positional_encoding_table", None, persistent=False, ) # Per-node-type input projection to hid_dim (like HGT's lin_dict) self._node_projection_dict = nn.ModuleDict( { str(node_type): nn.Linear(feat_dim, hid_dim) for node_type, feat_dim in node_type_to_feat_dim_map.items() } ) # PE fusion layers for node-level positional encodings. # In "concat" mode: projects [node_features || PE] → hid_dim # In "add" mode: projects PE → hid_dim, then adds to node features self._concat_pe_fusion_projection: Optional[nn.Module] = None has_node_level_pe = bool(pe_attr_names) if pe_integration_mode == "concat" and has_node_level_pe: self._concat_pe_fusion_projection = nn.LazyLinear(hid_dim) self._pe_projection: Optional[nn.Module] = None if pe_integration_mode == "add" and has_node_level_pe: self._pe_projection = nn.LazyLinear(hid_dim, bias=False) self._token_input_projection: Optional[nn.Module] = None if self._continuous_anchor_input_attr_names: self._token_input_projection = nn.LazyLinear(hid_dim, bias=False) self._anchor_pe_attention_bias_projection: Optional[nn.Linear] = None num_anchor_bias_attrs = len(self._anchor_based_attention_bias_attr_names or []) if num_anchor_bias_attrs > 0: self._anchor_pe_attention_bias_projection = nn.Linear( num_anchor_bias_attrs, num_heads, bias=False, ) self._pairwise_pe_attention_bias_projection: Optional[nn.Linear] = None if self._pairwise_attention_bias_attr_names: self._pairwise_pe_attention_bias_projection = nn.Linear( len(self._pairwise_attention_bias_attr_names), num_heads, bias=False, ) # Transformer encoder layers # Default feedforward ratio: 4.0 for standard activations, 8/3 for XGLU # XGLU's gating mechanism doubles effective parameters, so smaller ratio # maintains similar parameter count to standard activations with ratio 4. is_xglu = activation.lower() in _XGLU_BASE_ACTIVATIONS if feedforward_ratio is None: feedforward_ratio = 8.0 / 3.0 if is_xglu else 4.0 feedforward_dim = int(hid_dim * feedforward_ratio) self._encoder_layers = nn.ModuleList( [ GraphTransformerEncoderLayer( model_dim=hid_dim, num_heads=num_heads, feedforward_dim=feedforward_dim, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, activation=activation, ) for _ in range(num_layers) ] ) self._final_norm = nn.LayerNorm(hid_dim) # Readout attention: projects concatenated (anchor, neighbor) to score self._readout_attention = nn.Linear(2 * hid_dim, 1) # Output projection: hid_dim -> out_dim self._output_projection = nn.Linear(hid_dim, out_dim)
[docs] def forward( self, data: torch_geometric.data.hetero_data.HeteroData, anchor_node_type: Optional[NodeType] = None, anchor_node_ids: Optional[Tensor] = None, device: Optional[torch.device] = None, ) -> torch.Tensor: """Run the forward pass of the Graph Transformer encoder. Args: data: Input HeteroData object with node features (``x_dict``) and edge indices (``edge_index_dict``). anchor_node_type: Node type for which to compute embeddings. If None, uses the first node type in data. anchor_node_ids: Optional tensor of local node indices within anchor_node_type to use as anchors. If None, uses the first batch_size nodes (seed nodes from neighbor sampling). device: Torch device for output tensors. If None, inferred from data. Returns: Embeddings tensor of shape ``(num_anchor_nodes, out_dim)``. """ # Infer device from data if not provided if device is None: device = next(iter(data.x_dict.values())).device # Use first node type if not specified if anchor_node_type is None: anchor_node_type = list(data.node_types)[0] # 0. Apply feature embedding if provided (without modifying original data) # 1. Project all node features to hid_dim # Build a new x_dict with processed features to avoid in-place modifications projected_x_dict: dict[NodeType, torch.Tensor] = {} for node_type, x in data.x_dict.items(): x_processed = x.to(device) feature_embedding_layer = None if ( self._feature_embedding_layer_dict is not None and node_type in self._feature_embedding_layer_dict ): feature_embedding_layer = self._feature_embedding_layer_dict[node_type] # Apply feature embedding if available for this node type if feature_embedding_layer is not None: x_processed = feature_embedding_layer(x_processed) # Project to hid_dim x_projected = self._node_projection_dict[str(node_type)](x_processed) node_pe_parts = [] if self._pe_attr_names: node_pe_parts.append( _get_node_type_positional_encodings( data=data, node_type=node_type, pe_attr_names=self._pe_attr_names, device=device, ) ) if node_pe_parts: node_pe = torch.cat(node_pe_parts, dim=-1) if self._pe_integration_mode == "add": if self._pe_projection is None: raise ValueError("PE projection layer is not initialized.") x_projected = x_projected + self._pe_projection(node_pe) else: if self._concat_pe_fusion_projection is None: raise ValueError( "Concat PE fusion projection layer is not initialized." ) x_projected = self._concat_pe_fusion_projection( torch.cat([x_projected, node_pe], dim=-1) ) projected_x_dict[node_type] = x_projected # Create a new HeteroData with projected features (avoiding in-place modification) projected_data = torch_geometric.data.HeteroData() for node_type in data.node_types: projected_data[node_type].x = projected_x_dict[node_type] # Copy batch_size if it exists if hasattr(data[node_type], "batch_size"): projected_data[node_type].batch_size = data[node_type].batch_size for edge_type in data.edge_types: projected_data[edge_type].edge_index = data[edge_type].edge_index if hasattr(data[edge_type], "edge_attr"): projected_data[edge_type].edge_attr = data[edge_type].edge_attr # Copy relative-encoding attributes (e.g., hop_distance stored as sparse matrix) relative_pe_attr_names = { attr_name for attr_name in (self._anchor_based_attention_bias_attr_names or []) if attr_name != PPR_WEIGHT_FEATURE_NAME } relative_pe_attr_names.update(self._anchor_based_input_attr_names or []) relative_pe_attr_names.update(self._pairwise_attention_bias_attr_names or []) relative_pe_attr_names.discard(PPR_WEIGHT_FEATURE_NAME) if relative_pe_attr_names: for attr_name in sorted(relative_pe_attr_names): if hasattr(data, attr_name): setattr(projected_data, attr_name, getattr(data, attr_name)) # 2. Build sequences and run transformer # If anchor_node_ids provided, use those; otherwise use first batch_size nodes if anchor_node_ids is not None: num_anchor_nodes = anchor_node_ids.size(0) else: num_anchor_nodes = getattr( projected_data[anchor_node_type], "batch_size", projected_data[anchor_node_type].num_nodes, ) ( sequences, valid_mask, sequence_auxiliary_data, ) = heterodata_to_graph_transformer_input( data=projected_data, batch_size=num_anchor_nodes, max_seq_len=self._max_seq_len, anchor_node_type=anchor_node_type, anchor_node_ids=anchor_node_ids, hop_distance=self._hop_distance, sequence_construction_method=self._sequence_construction_method, anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, anchor_based_input_attr_names=self._anchor_based_input_attr_names, pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, ) # Free memory after sequences are built del projected_data if sequences.size(-1) != self._hid_dim: raise ValueError( f"Expected sequence dim {self._hid_dim} after node projection, " f"got {sequences.size(-1)}." ) token_input_features = sequence_auxiliary_data["token_input"] if token_input_features is not None: sequences = sequences + self._build_token_input_contribution( token_input_features=token_input_features, sequences=sequences, valid_mask=valid_mask, ) sequence_positional_encoding = self._get_sequence_positional_encoding( valid_mask=valid_mask, sequences=sequences, ) if sequence_positional_encoding is not None: sequences = sequences + sequence_positional_encoding attn_bias = self._build_attention_bias( valid_mask=valid_mask, sequences=sequences, attention_bias_data=sequence_auxiliary_data, ) embeddings = self._encode_and_readout( sequences=sequences, valid_mask=valid_mask, attn_bias=attn_bias, ) embeddings = self._output_projection(embeddings) if self._should_l2_normalize_embedding_layer_output: embeddings = F.normalize(embeddings, p=2, dim=-1) return embeddings
def _get_sequence_positional_encoding( self, valid_mask: Tensor, sequences: Tensor, ) -> Optional[Tensor]: if self._sequence_positional_encoding_type is None: return None if self._sequence_positional_encoding_type != "sinusoidal": raise ValueError( "Unsupported sequence_positional_encoding_type " f"'{self._sequence_positional_encoding_type}'." ) if self._sequence_positional_encoding_table is None: raise ValueError("Sequence positional encoding table is not initialized.") position_table = cast(Tensor, self._sequence_positional_encoding_table) seq_len = sequences.size(1) if seq_len > position_table.size(0): raise ValueError( f"Sequence length {seq_len} exceeds configured max_seq_len " f"{position_table.size(0)}." ) position_encoding = position_table[:seq_len] position_encoding = position_encoding.to( device=sequences.device, dtype=sequences.dtype, ) position_encoding = position_encoding.unsqueeze(0).expand( sequences.size(0), -1, -1 ) return position_encoding * valid_mask.unsqueeze(-1).to(sequences.dtype) def _build_token_input_contribution( self, token_input_features: TokenInputData, sequences: Tensor, valid_mask: Tensor, ) -> Tensor: token_contribution = torch.zeros_like(sequences) valid_token_mask = valid_mask.unsqueeze(-1).to(sequences.dtype) if self._anchor_based_input_embedding_dict is not None: for ( attr_name, embedding_layer, ) in self._anchor_based_input_embedding_dict.items(): if attr_name not in token_input_features: raise ValueError( f"Token-input feature '{attr_name}' is missing from the " "sequence auxiliary data." ) indices = token_input_features[attr_name] if indices.size(-1) != 1: raise ValueError( f"Embedded token-input feature '{attr_name}' must have " f"shape (batch, seq, 1), got {indices.shape}." ) embedded_attr = embedding_layer(indices.squeeze(-1).long()) if embedded_attr.shape != sequences.shape: raise ValueError( f"Embedded token-input feature '{attr_name}' must produce " f"shape {sequences.shape}, got {embedded_attr.shape}." ) token_contribution = token_contribution + ( embedded_attr.to(sequences.dtype) * valid_token_mask ) if self._continuous_anchor_input_attr_names: if self._token_input_projection is None: raise ValueError("Token-input projection is not initialized.") continuous_feature_parts: list[Tensor] = [] for attr_name in self._continuous_anchor_input_attr_names: if attr_name not in token_input_features: raise ValueError( f"Token-input feature '{attr_name}' is missing from the " "sequence auxiliary data." ) continuous_feature_parts.append(token_input_features[attr_name]) token_contribution = token_contribution + ( self._token_input_projection( torch.cat(continuous_feature_parts, dim=-1).to(sequences.dtype) ) * valid_token_mask ) return token_contribution def _build_attention_bias( self, valid_mask: Tensor, sequences: Tensor, attention_bias_data: SequenceAuxiliaryData, ) -> Tensor: """Build additive attention bias from padding mask and learned relative PE projections. This function constructs a combined attention bias tensor that is added to attention scores before softmax. The bias has three components: 1. **Padding mask bias**: Sets padded positions to -inf so they receive zero attention weight after softmax. Shape: (batch, 1, 1, seq) broadcasts to (batch, num_heads, seq, seq) for key masking. 2. **Anchor-relative bias** (optional): For each sequence position, looks up the PE value relative to the anchor (e.g., hop distance from anchor). Input shape: (batch, seq, num_anchor_attrs) After projection: (batch, num_heads, 1, seq) - same bias for all query positions. 3. **Pairwise bias** (optional): For each (query, key) pair, looks up the PE value between those two nodes (e.g., random walk structural encoding). Input shape: (batch, seq, seq, num_pairwise_attrs) After projection: (batch, num_heads, seq, seq) - unique bias per query-key pair. Args: valid_mask: Boolean mask of shape (batch_size, seq_len) indicating valid (non-padding) positions. sequences: Input sequences of shape (batch_size, seq_len, hid_dim), used only to infer dtype and device. attention_bias_data: Dictionary containing optional PE tensors: - "anchor_bias": (batch, seq, num_anchor_attrs) or None - "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None Returns: Combined attention bias tensor of shape (batch_size, num_heads, seq_len, seq_len) or broadcastable shape. Added to attention scores before softmax. Example: # With batch_size=2, seq_len=4, num_heads=8 # valid_mask = [[T, T, T, F], [T, T, F, F]] # # Output attn_bias shape: (2, 8, 4, 4) # - Positions where valid_mask is False get -inf # - Anchor bias adds per-key bias (same for all queries) # - Pairwise bias adds unique bias for each (query, key) pair """ batch_size, seq_len = valid_mask.shape dtype = sequences.dtype device = sequences.device negative_inf = torch.finfo(dtype).min # Step 1: Initialize with padding mask bias # Shape: (batch, 1, 1, seq) - broadcasts to mask invalid keys for all queries/heads attn_bias = torch.zeros( (batch_size, 1, 1, seq_len), dtype=dtype, device=device, ) attn_bias = attn_bias.masked_fill( ~valid_mask.unsqueeze(1).unsqueeze(2), # (batch, 1, 1, seq) negative_inf, ) # Step 2: Add anchor-relative bias (optional) # Projects (batch, seq, num_attrs) → (batch, seq, num_heads) # Then reshapes to (batch, num_heads, 1, seq) for key-side bias anchor_bias_features = attention_bias_data.get("anchor_bias") if anchor_bias_features is not None: if self._anchor_pe_attention_bias_projection is None: raise ValueError("Anchor attention-bias projection is not initialized.") anchor_bias = self._anchor_pe_attention_bias_projection( anchor_bias_features.to(dtype) ) # (batch, seq, num_heads) anchor_bias = anchor_bias.permute(0, 2, 1).unsqueeze( 2 ) # (batch, num_heads, 1, seq) attn_bias = attn_bias + anchor_bias # Step 3: Add pairwise bias (optional) # Projects (batch, seq, seq, num_attrs) → (batch, seq, seq, num_heads) # Then reshapes to (batch, num_heads, seq, seq) pairwise_bias_features = attention_bias_data.get("pairwise_bias") if pairwise_bias_features is not None: if self._pairwise_pe_attention_bias_projection is None: raise ValueError( "Pairwise attention-bias projection is not initialized." ) pairwise_bias = self._pairwise_pe_attention_bias_projection( pairwise_bias_features.to(dtype) ) # (batch, seq, seq, num_heads) pairwise_bias = pairwise_bias.permute( 0, 3, 1, 2 ) # (batch, num_heads, seq, seq) attn_bias = attn_bias + pairwise_bias return attn_bias def _encode_and_readout( self, sequences: Tensor, valid_mask: Tensor, attn_bias: Optional[Tensor] = None, ) -> Tensor: """Process sequences through transformer layers and attention readout. Args: sequences: Input tensor of shape ``(batch_size, max_seq_len, hid_dim)``. valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. attn_bias: Optional additive attention bias broadcastable to ``(batch_size, num_heads, seq, seq)``. Returns: Output embeddings of shape ``(batch_size, hid_dim)``. """ x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype) for encoder_layer in self._encoder_layers: x = encoder_layer(x, attn_bias=attn_bias, valid_mask=valid_mask) x = self._final_norm(x) x = x * valid_mask.unsqueeze(-1).to(x.dtype) # Readout: anchor (position 0) + attention-weighted neighbor aggregation anchor = x[:, 0, :].unsqueeze(1) # (batch, 1, hid_dim) neighbors = x[:, 1:, :] # (batch, seq-1, hid_dim) neighbor_valid_mask = valid_mask[:, 1:] seq_minus_one = neighbors.size(1) if seq_minus_one == 0: return anchor.squeeze(1) # Expand anchor to match neighbor dimension for concatenation anchor_expanded = anchor.expand(-1, seq_minus_one, -1) # Compute attention scores over neighbors readout_scores = self._readout_attention( torch.cat([anchor_expanded, neighbors], dim=-1) ) # (batch, seq-1, 1) readout_scores = readout_scores.masked_fill( ~neighbor_valid_mask.unsqueeze(-1), torch.finfo(readout_scores.dtype).min, ) readout_weights = F.softmax(readout_scores, dim=1) # (batch, seq-1, 1) readout_weights = torch.nan_to_num(readout_weights, nan=0.0) readout_weights = readout_weights * neighbor_valid_mask.unsqueeze(-1).to( readout_weights.dtype ) neighbor_aggregation = (neighbors * readout_weights).sum( dim=1, keepdim=True ) # (batch, 1, hid_dim) output = (anchor + neighbor_aggregation).squeeze(1) # (batch, hid_dim) return output