gigl.nn.graph_transformer#

Graph Transformer encoder for heterogeneous graphs.

Adapted from RelGT’s LocalModule (snap-stanford/relgt). Converts heterogeneous graph data into fixed-length sequences via heterodata_to_graph_transformer_input, processes them through a stack of pre-norm transformer encoder layers, and produces per-node embeddings via an attention-weighted neighbor readout.

Conforms to the same forward interface as HGT and SimpleHGN in gigl.src.common.models.pyg.heterogeneous, making it a drop-in replacement as the encoder in LinkPredictionGNN.

Classes#

FeedForwardNetwork

Two-layer feed-forward network with configurable activation.

GraphTransformerEncoder

Graph Transformer encoder for heterogeneous graphs.

GraphTransformerEncoderLayer

Pre-norm transformer encoder layer with multi-head self-attention.

Module Contents#

class gigl.nn.graph_transformer.FeedForwardNetwork(model_dim, feedforward_dim, dropout_rate=0.1, activation='gelu')[source]#

Bases: torch.nn.Module

Two-layer feed-forward network with configurable activation.

Supports standard activations (GELU, ReLU, SiLU) and XGLU family (SwiGLU, GeGLU, ReGLU) which use a gating mechanism.

Note: This module does NOT include LayerNorm. Normalization should be applied externally (e.g., pre-norm in the transformer layer).

Adapted from RelGT’s FeedForwardNetwork.

Parameters:
  • model_dim (int) – Model (input and output) dimension of the FFN.

  • feedforward_dim (int) – Inner dimension of the two-layer MLP.

  • dropout_rate (float) – Dropout probability applied after each linear layer.

  • activation (str) – Activation function name. Supported values: standard activations “gelu” (default), “relu”, “silu”, “tanh”; XGLU family “geglu”, “swiglu”, “reglu”. XGLU activations use gating, activation(xW) * xV, which requires projecting to 2x feedforward_dim internally (see the sketch below).
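
The gating used by the XGLU variants can be written out with plain tensors (an illustrative sketch; the dimensions are assumptions, not the module’s internals):

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(2, 8, 64)               # (batch, seq, model_dim)
>>> w_in = torch.randn(64, 2 * 256)         # projects to 2 * feedforward_dim
>>> w_out = torch.randn(256, 64)            # projects back to model_dim
>>> gate, value = (x @ w_in).chunk(2, dim=-1)
>>> out = (F.gelu(gate) * value) @ w_out    # activation(xW) * xV
>>> out.shape
torch.Size([2, 8, 64])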

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]#

Forward pass.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch, seq, model_dim).

Returns:

Output tensor of shape (batch, seq, model_dim).

Return type:

torch.Tensor
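
A hedged usage sketch with illustrative dimensions:

>>> import torch
>>> ffn = FeedForwardNetwork(model_dim=128, feedforward_dim=512, activation="swiglu")
>>> x = torch.randn(4, 16, 128)  # (batch, seq, model_dim)
>>> ffn(x).shape
torch.Size([4, 16, 128])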

reset_parameters()[source]#

Reinitialize all learnable parameters.

Return type:

None

class gigl.nn.graph_transformer.GraphTransformerEncoder(node_type_to_feat_dim_map, edge_type_to_feat_dim_map, hid_dim, out_dim=128, num_layers=2, num_heads=2, max_seq_len=128, hop_distance=2, sequence_construction_method='khop', sequence_positional_encoding_type=None, dropout_rate=0.1, attention_dropout_rate=0.0, should_l2_normalize_embedding_layer_output=False, pe_attr_names=None, anchor_based_attention_bias_attr_names=None, anchor_based_input_attr_names=None, anchor_based_input_embedding_dict=None, pairwise_attention_bias_attr_names=None, feature_embedding_layer_dict=None, pe_integration_mode='concat', activation='gelu', feedforward_ratio=None, **kwargs)[source]#

Bases: torch.nn.Module

Graph Transformer encoder for heterogeneous graphs.

Converts heterogeneous graph data into fixed-length sequences via heterodata_to_graph_transformer_input, processes them through pre-norm transformer encoder layers, and produces per-node embeddings via an attention-weighted neighbor readout (from RelGT’s LocalModule).

Conforms to the same forward interface as HGT and SimpleHGN, making it a drop-in encoder for LinkPredictionGNN.

Parameters:
  • node_type_to_feat_dim_map (dict[gigl.src.common.types.graph_data.NodeType, int]) – Dictionary mapping node types to their input feature dimensions.

  • edge_type_to_feat_dim_map (dict[gigl.src.common.types.graph_data.EdgeType, int]) – Dictionary mapping edge types to their feature dimensions. Accepted for interface conformance with HGT/SimpleHGN; edge features are not used by the graph transformer.

  • hid_dim (int) – Hidden dimension for transformer layers. All node types are projected to this dimension before processing.

  • out_dim (int) – Output embedding dimension.

  • num_layers (int) – Number of transformer encoder layers.

  • num_heads (int) – Number of attention heads per layer. Must evenly divide hid_dim.

  • max_seq_len (int) – Maximum sequence length for the graph-to-sequence transform. Neighborhoods are truncated to this length.

  • hop_distance (int) – Number of hops for neighborhood extraction in the graph-to-sequence transform when using "khop" sequence construction.

  • sequence_construction_method (Literal['khop', 'ppr']) – Sequence builder used to create tokens for each anchor. "khop" expands the sampled graph by hop distance, while "ppr" consumes outgoing "ppr" edges sorted by weight.

  • sequence_positional_encoding_type (Optional[str]) – Optional sequence-level positional encoding applied after sequence construction. Supported values are None and "sinusoidal". Lower-cost future extensions could add learned absolute position embeddings here, while attention-level options like RoPE or ALiBi would require changes inside the attention block.

  • dropout_rate (float) – Dropout probability for feed-forward layers.

  • attention_dropout_rate (float) – Dropout probability for attention weights.

  • should_l2_normalize_embedding_layer_output (bool) – Whether to L2 normalize output embeddings.

  • pe_attr_names (Optional[list[str]]) – List of node-level positional encoding attribute names. In "concat" mode these are concatenated to sequence features. In "add" mode they are projected to hid_dim and added to node features before sequence construction.

  • anchor_based_attention_bias_attr_names (Optional[list[str]]) – List of anchor-relative feature names used as additive attention bias for sequence keys. Sparse graph-level attributes are looked up from data and the reserved name "ppr_weight" resolves to PPR edge weights in PPR mode. Example: ['hop_distance', 'ppr_weight'] where hop_distance is a sparse matrix attribute on data and ppr_weight is extracted from PPR edge weights.

  • anchor_based_input_attr_names (Optional[list[str]]) – List of anchor-relative attribute names used as token-aligned input features. Sparse graph-level attributes are looked up from data and "ppr_weight" resolves to PPR edge weights in PPR mode. These are projected to hid_dim and added to the sequence tokens after sequence construction. Example: ['hop_distance', 'ppr_weight'] for continuous features, or ['hop_distance'] when hop_distance will be embedded via anchor_based_input_embedding_dict.

  • anchor_based_input_embedding_dict (Optional[torch.nn.ModuleDict]) – Optional ModuleDict mapping a subset of anchor_based_input_attr_names to per-attribute embedding layers. These attributes are treated as discrete indices and their embedded contributions are added to the sequence tokens. Padding is masked out using the sequence valid mask. Example: nn.ModuleDict({'hop_distance': nn.Embedding(10, hid_dim)}) to embed hop distances 0-9 into hid_dim-dimensional vectors. The embedding output dimension must match hid_dim.

  • pairwise_attention_bias_attr_names (Optional[list[str]]) – List of pairwise feature names used as additive attention bias. These must correspond to sparse graph-level attributes on data.

  • feature_embedding_layer_dict (Optional[torch.nn.ModuleDict]) – Optional ModuleDict mapping node types to feature embedding layers. If provided, these are applied to node features before node projection. (default: None)

  • pe_integration_mode (Literal['concat', 'add']) – How to fuse positional encodings into the model input. "concat" preserves the current behavior by concatenating node-level PE to token features. "add" uses node-level additive PE before sequence construction and attention bias for relative encodings.

  • activation (str) – Activation function for the feed-forward network in each transformer layer. Supported values: standard activations “gelu” (default), “relu”, “silu”, “tanh”; XGLU family “geglu”, “swiglu”, “reglu”. XGLU activations use gating: activation(xW) * xV.

  • feedforward_ratio (Optional[float]) – Ratio of the feedforward dimension to the hidden dimension (feedforward_dim = hid_dim * feedforward_ratio). If None (default), uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, following the convention that XGLU’s gating doubles the width of the inner projection, so a smaller ratio keeps the parameter count comparable (see the sketch after this list).

  • kwargs (object)
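
The default ratios keep the feed-forward parameter counts roughly comparable across activation families; a hedged back-of-the-envelope count (ignoring biases) for hid_dim=128:

>>> hid_dim = 128
>>> standard = 2 * hid_dim * int(4.0 * hid_dim)              # W_in + W_out for “gelu”
>>> xglu_inner = int(8 / 3 * hid_dim)
>>> xglu = hid_dim * 2 * xglu_inner + xglu_inner * hid_dim   # gated W_in (2x width) + W_out
>>> standard, xglu
(131072, 130944)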

Notes

This encoder uses nn.LazyLinear for node-level PE fusion. If you wrap it with DistributedDataParallel, first run one representative forward under torch.no_grad() (passing anchor_node_ids/anchor_node_type so the graph-transformer path is exercised), or load a checkpoint before wrapping, so that all ranks see initialized weights.
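
A hedged sketch of that initialization order; data, anchor_ids, and device are placeholders for a representative batch from your own sampling pipeline:

>>> import torch
>>> encoder = GraphTransformerEncoder(
...     node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32},
...     edge_type_to_feat_dim_map={},
...     hid_dim=128,
... )
>>> with torch.no_grad():  # one representative forward materializes the LazyLinear weights
...     _ = encoder(data, anchor_node_type=NodeType("user"), anchor_node_ids=anchor_ids, device=device)
>>> ddp_encoder = torch.nn.parallel.DistributedDataParallel(encoder)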

TODO: Pairwise relative bias is currently materialized densely for the selected sequence. That is fine for moderate max_seq_len, but a chunked or sparse LPFormer-style path is still future work for larger sequences.

Example

>>> from gigl.nn.graph_transformer import GraphTransformerEncoder
>>> from gigl.src.common.types.graph_data import NodeType
>>> encoder = GraphTransformerEncoder(
...     node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32},
...     edge_type_to_feat_dim_map={},
...     hid_dim=128,
...     out_dim=64,
...     num_layers=2,
...     num_heads=4,
... )
>>> embeddings = encoder(data, anchor_node_type=NodeType("user"), device=device)
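
A second, hedged configuration sketch that exercises several of the optional inputs; the attribute names “laplacian_pe” and “hop_distance” are illustrative and must exist on your data:

>>> import torch.nn as nn
>>> encoder = GraphTransformerEncoder(
...     node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32},
...     edge_type_to_feat_dim_map={},
...     hid_dim=128,
...     out_dim=64,
...     sequence_construction_method="ppr",
...     pe_attr_names=["laplacian_pe"],
...     pe_integration_mode="add",
...     anchor_based_attention_bias_attr_names=["hop_distance", "ppr_weight"],
...     anchor_based_input_attr_names=["hop_distance"],
...     anchor_based_input_embedding_dict=nn.ModuleDict({"hop_distance": nn.Embedding(10, 128)}),
... )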

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(data, anchor_node_type=None, anchor_node_ids=None, device=None)[source]#

Run the forward pass of the Graph Transformer encoder.

Parameters:
  • data (torch_geometric.data.hetero_data.HeteroData) – Input HeteroData object with node features (x_dict) and edge indices (edge_index_dict).

  • anchor_node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – Node type for which to compute embeddings. If None, uses the first node type in data.

  • anchor_node_ids (Optional[torch.Tensor]) – Optional tensor of local node indices within anchor_node_type to use as anchors. If None, uses the first batch_size nodes (seed nodes from neighbor sampling).

  • device (Optional[torch.device]) – Torch device for output tensors. If None, inferred from data.

Returns:

Embeddings tensor of shape (num_anchor_nodes, out_dim).

Return type:

torch.Tensor
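
A hedged usage sketch that selects explicit anchors, reusing the encoder and data assumed in the class-level example above:

>>> anchor_ids = torch.tensor([0, 3, 7])
>>> emb = encoder(data, anchor_node_type=NodeType("user"), anchor_node_ids=anchor_ids)
>>> emb.shape
torch.Size([3, 64])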

class gigl.nn.graph_transformer.GraphTransformerEncoderLayer(model_dim, num_heads, feedforward_dim, dropout_rate=0.1, attention_dropout_rate=0.0, activation='gelu')[source]#

Bases: torch.nn.Module

Pre-norm transformer encoder layer with multi-head self-attention.

Uses F.scaled_dot_product_attention which automatically selects the most efficient attention implementation (flash, memory-efficient, or math-based) based on input properties and hardware.

Adapted from RelGT’s EncoderLayer.

Parameters:
  • model_dim (int) – Model dimension (d_model).

  • num_heads (int) – Number of attention heads. Must evenly divide model_dim.

  • feedforward_dim (int) – Inner dimension of the feed-forward network.

  • dropout_rate (float) – Dropout probability for feed-forward layers.

  • attention_dropout_rate (float) – Dropout probability for attention weights.

  • activation (str) – Activation function for the feed-forward network. Supported values: “gelu” (default), “relu”, “silu”, “tanh”, “geglu”, “swiglu”, “reglu”.

Raises:

ValueError – If model_dim is not divisible by num_heads.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, attn_bias=None, valid_mask=None)[source]#

Forward pass.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch, seq, model_dim).

  • attn_bias (Optional[torch.Tensor]) – Optional attention bias of shape (batch, num_heads, seq, seq) or broadcastable. Added as an additive mask to attention scores.

  • valid_mask (Optional[torch.Tensor]) – Optional boolean tensor of shape (batch, seq) used to zero out padded token states after each residual block.

Returns:

Output tensor of shape (batch, seq, model_dim).

Return type:

torch.Tensor
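
A hedged usage sketch with illustrative shapes:

>>> import torch
>>> layer = GraphTransformerEncoderLayer(model_dim=128, num_heads=4, feedforward_dim=512)
>>> x = torch.randn(2, 16, 128)                       # (batch, seq, model_dim)
>>> attn_bias = torch.zeros(2, 4, 16, 16)             # (batch, num_heads, seq, seq)
>>> valid_mask = torch.ones(2, 16, dtype=torch.bool)  # True for real tokens
>>> valid_mask[:, 12:] = False                        # last four tokens are padding
>>> layer(x, attn_bias=attn_bias, valid_mask=valid_mask).shape
torch.Size([2, 16, 128])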

reset_parameters()[source]#

Reinitialize all learnable parameters.

Return type:

None