gigl.nn.graph_transformer
Graph Transformer encoder for heterogeneous graphs.
Adapted from RelGT’s LocalModule (snap-stanford/relgt).
Converts heterogeneous graph data into fixed-length sequences via
heterodata_to_graph_transformer_input, processes them through a stack of pre-norm
transformer encoder layers, then produces per-node embeddings via configurable
anchor or anchor-neighbor readout.
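The graph-to-sequence step can be pictured with a small pure-Python sketch for a single anchor. It makes several assumptions: a homogeneous adjacency list, breadth-first (hop-by-hop) token order with the anchor as token 0, truncation at max_seq_len, and -1 padding with a boolean valid mask. The helper khop_sequence and these conventions are hypothetical and are not taken from heterodata_to_graph_transformer_input.

```python
from collections import deque


def khop_sequence(
    adj: dict[int, list[int]],
    anchor: int,
    hop_distance: int,
    max_seq_len: int,
    pad_id: int = -1,
) -> tuple[list[int], list[bool]]:
    """Hypothetical k-hop sequence builder: BFS from the anchor, then truncate and pad."""
    seen = {anchor}
    order = [anchor]  # assumption: the anchor is always token 0
    frontier = deque([(anchor, 0)])
    while frontier and len(order) < max_seq_len:
        node, depth = frontier.popleft()
        if depth == hop_distance:
            continue  # never expand beyond hop_distance hops
        for nbr in adj.get(node, []):
            if nbr not in seen:
                seen.add(nbr)
                order.append(nbr)
                frontier.append((nbr, depth + 1))
    order = order[:max_seq_len]  # truncate neighborhoods that are too large
    num_pad = max_seq_len - len(order)
    valid = [True] * len(order) + [False] * num_pad
    return order + [pad_id] * num_pad, valid


seq, valid = khop_sequence({0: [1, 2], 1: [3], 2: [3, 4], 3: [5]}, anchor=0, hop_distance=2, max_seq_len=6)
print(seq)  # [0, 1, 2, 3, 4, -1]
```

Node 5 is three hops from anchor 0, so with hop_distance=2 it never enters the sequence. The unused slot is padded and marked invalid, which is the role the valid mask plays in the transformer layers.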
Conforms to the same forward interface as HGT and SimpleHGN in
gigl.src.common.models.pyg.heterogeneous, making it a drop-in
replacement as the encoder in LinkPredictionGNN.
Classes
FeedForwardNetwork | Two-layer feed-forward network with configurable activation.
GraphTransformerEncoder | Graph Transformer encoder for heterogeneous graphs.
GraphTransformerEncoderLayer | Pre-norm transformer encoder layer with multi-head self-attention.
Module Contents
- class gigl.nn.graph_transformer.FeedForwardNetwork(model_dim, feedforward_dim, dropout_rate=0.1, activation='gelu')
Bases: torch.nn.Module
Two-layer feed-forward network with configurable activation.
Supports standard activations (GELU, ReLU, SiLU, Tanh) and the XGLU family (SwiGLU, GeGLU, ReGLU), which use a gating mechanism.
Note: This module does NOT include LayerNorm. Normalization should be applied externally (e.g., pre-norm in the transformer layer).
Adapted from RelGT’s FeedForwardNetwork.
- Parameters:
model_dim (int) – Model (input and output) dimension of the FFN.
feedforward_dim (int) – Inner dimension of the two-layer MLP.
dropout_rate (float) – Dropout probability applied after each linear layer.
activation (str) – Activation function name. Supported values:
  - Standard: “gelu” (default), “relu”, “silu”, “tanh”
  - XGLU family: “geglu”, “swiglu”, “reglu”
  XGLU activations use gating, activation(xW) * xV, which requires projecting to 2x feedforward_dim internally.
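The XGLU gating described for the activation parameter can be sketched as a minimal SwiGLU-style module. The sketch assumes a single fused projection to 2x feedforward_dim that is split into gate and value halves. GatedFeedForwardSketch is a hypothetical name, and GiGL's FeedForwardNetwork may organize its layers differently. Like the documented module, it applies no LayerNorm.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedFeedForwardSketch(nn.Module):
    """Hypothetical SwiGLU-style FFN: silu(x W) * (x V), then project back to model_dim."""

    def __init__(self, model_dim: int, feedforward_dim: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        # One fused projection to 2 * feedforward_dim holds both W (gate) and V (value).
        self.in_proj = nn.Linear(model_dim, 2 * feedforward_dim)
        self.out_proj = nn.Linear(feedforward_dim, model_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.in_proj(x).chunk(2, dim=-1)
        hidden = self.dropout(F.silu(gate) * value)  # SiLU gate gives SwiGLU
        return self.dropout(self.out_proj(hidden))


ffn = GatedFeedForwardSketch(model_dim=16, feedforward_dim=32, dropout_rate=0.0)
y = ffn(torch.randn(2, 5, 16))  # shape (2, 5, 16); no normalization inside
```

Replacing F.silu with F.gelu or F.relu gives the GeGLU or ReGLU variants. Normalization is left to the surrounding pre-norm layer, as the note above requires.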
- class gigl.nn.graph_transformer.GraphTransformerEncoder(node_type_to_feat_dim_map, edge_type_to_feat_dim_map, hid_dim, out_dim=128, num_layers=2, num_heads=2, max_seq_len=128, hop_distance=2, sequence_construction_method='khop', sequence_positional_encoding_type=None, readout_mode='anchor_neighbor_attention', dropout_rate=0.1, attention_dropout_rate=0.0, should_l2_normalize_embedding_layer_output=False, pe_attr_names=None, anchor_based_attention_bias_attr_names=None, anchor_based_input_attr_names=None, anchor_based_input_embedding_dict=None, pairwise_attention_bias_attr_names=None, sampling_direction='out', feature_embedding_layer_dict=None, pe_integration_mode='concat', activation='gelu', feedforward_ratio=None, relation_attention_mode='none', **kwargs)
Bases: torch.nn.Module
Graph Transformer encoder for heterogeneous graphs.
Converts heterogeneous graph data into fixed-length sequences via heterodata_to_graph_transformer_input, processes them through pre-norm transformer encoder layers, and produces per-node embeddings via configurable anchor or anchor-neighbor readout.
Conforms to the same forward interface as HGT and SimpleHGN, making it a drop-in encoder for LinkPredictionGNN.
- Parameters:
node_type_to_feat_dim_map (dict[gigl.src.common.types.graph_data.NodeType, int]) – Dictionary mapping node types to their input feature dimensions.
edge_type_to_feat_dim_map (dict[gigl.src.common.types.graph_data.EdgeType, int]) – Dictionary mapping edge types to their feature dimensions. Accepted for interface conformance with HGT/SimpleHGN; edge features are not used by the graph transformer.
hid_dim (int) – Hidden dimension for transformer layers. All node types are projected to this dimension before processing.
out_dim (int) – Output embedding dimension.
num_layers (int) – Number of transformer encoder layers.
num_heads (int) – Number of attention heads per layer. Must evenly divide hid_dim.
max_seq_len (int) – Maximum sequence length for the graph-to-sequence transform. Neighborhoods are truncated to this length.
hop_distance (int) – Number of hops for neighborhood extraction in the graph-to-sequence transform when using "khop" sequence construction.
sequence_construction_method (Literal['khop', 'ppr']) – Sequence builder used to create tokens for each anchor. "khop" expands the sampled graph by hop distance, while "ppr" consumes outgoing "ppr" edges sorted by weight.
sequence_positional_encoding_type (Optional[str]) – Optional sequence-level positional encoding applied after sequence construction. Supported values are None and "sinusoidal". Lower-cost future extensions could add learned absolute position embeddings here, while attention-level options such as RoPE or ALiBi would require changes inside the attention block.
readout_mode (Literal['anchor_neighbor_attention', 'anchor_only']) – Strategy used to turn encoded sequence tokens into one embedding per anchor. "anchor_neighbor_attention" (default) adds the encoded anchor token to an attention-weighted aggregation of its neighbor tokens. "anchor_only" returns only the encoded anchor token.
dropout_rate (float) – Dropout probability for feed-forward layers.
attention_dropout_rate (float) – Dropout probability for attention weights.
should_l2_normalize_embedding_layer_output (bool) – Whether to L2 normalize output embeddings.
pe_attr_names (Optional[list[str]]) – List of node-level positional encoding attribute names. In "concat" mode these are concatenated to sequence features. In "add" mode they are projected to hid_dim and added to node features before sequence construction.
anchor_based_attention_bias_attr_names (Optional[list[str]]) – List of anchor-relative feature names used as additive attention bias for sequence keys. Sparse graph-level attributes are looked up from data, and the reserved name "ppr_weight" resolves to PPR edge weights in PPR mode. Example: ['hop_distance', 'ppr_weight'], where hop_distance is a sparse matrix attribute on data and ppr_weight is extracted from PPR edge weights.
anchor_based_input_attr_names (Optional[list[str]]) – List of anchor-relative attribute names used as token-aligned input features. Sparse graph-level attributes are looked up from data, and "ppr_weight" resolves to PPR edge weights in PPR mode. These are projected to hid_dim and added to the sequence tokens after sequence construction. Example: ['hop_distance', 'ppr_weight'] for continuous features, or ['hop_distance'] when hop_distance will be embedded via anchor_based_input_embedding_dict.
anchor_based_input_embedding_dict (Optional[torch.nn.ModuleDict]) – Optional ModuleDict mapping a subset of anchor_based_input_attr_names to per-attribute embedding layers. These attributes are treated as discrete indices, and their embedded contributions are added to the sequence tokens. Padding is masked out using the sequence valid mask. Example: nn.ModuleDict({'hop_distance': nn.Embedding(10, hid_dim)}) embeds hop distances 0-9 into hid_dim-dimensional vectors. The embedding output dimension must match hid_dim.
pairwise_attention_bias_attr_names (Optional[list[str]]) – List of pairwise feature names used as additive attention bias. These must correspond to sparse graph-level attributes on data.
sampling_direction (Literal['in', 'out']) – Direction used for sequence token construction. "out" (default) uses the standard k-hop reachability expansion. "in" expands over reversed edges and is supported only when sequence_construction_method="khop". Directed relative encodings such as "hop_distance" should be computed with the same direction.
feature_embedding_layer_dict (Optional[torch.nn.ModuleDict]) – Optional ModuleDict mapping node types to feature embedding layers. If provided, these are applied to node features before node projection. (default: None)
pe_integration_mode (Literal['concat', 'add']) – How to fuse positional encodings into the model input. "concat" (default) concatenates node-level PE to token features. "add" adds projected node-level PE to node features before sequence construction and applies relative encodings as attention bias.
activation (str) – Activation function for the feed-forward network in each transformer layer. Supported values:
  - Standard: “gelu” (default), “relu”, “silu”, “tanh”
  - XGLU family: “geglu”, “swiglu”, “reglu”
  XGLU activations use gating: activation(xW) * xV.
feedforward_ratio (Optional[float]) – Ratio of feedforward dimension to hidden dimension (feedforward_dim = hid_dim * feedforward_ratio). If None (default), uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants. XGLU's gating doubles the first projection (to 2x feedforward_dim), so the FFN holds about 3 * hid_dim * feedforward_dim weights instead of 2 * hid_dim * feedforward_dim. The smaller 8/3 ratio keeps the parameter count roughly equal to the standard 4.0 setting.
relation_attention_mode (Literal['none', 'edge_type_bilinear', 'edge_type_hgt']) – Optional relation-aware augmentation for attention scores. "none" (default) keeps the standard transformer attention path. "edge_type_bilinear" adds a learned per-edge-type bilinear score term for sampled directed edges. "edge_type_hgt" replaces base query/key scores on relation edges with an HGT-style relation transform and relation prior.
kwargs (object)
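The feedforward_ratio defaults can be checked with a little arithmetic. This sketch counts only weight-matrix entries (biases are ignored). It assumes the inner dimension is hid_dim * feedforward_ratio rounded to the nearest integer. ffn_weight_count is a hypothetical helper, and this page does not specify the encoder's exact rounding rule.

```python
def ffn_weight_count(hid_dim: int, feedforward_ratio: float, gated: bool) -> int:
    """Weight entries in a two-layer FFN (biases ignored); hypothetical helper."""
    feedforward_dim = round(hid_dim * feedforward_ratio)  # rounding rule is an assumption
    # XGLU variants project to 2 * feedforward_dim so the output can be split into gate and value.
    first_layer = hid_dim * (2 * feedforward_dim if gated else feedforward_dim)
    second_layer = feedforward_dim * hid_dim
    return first_layer + second_layer


standard = ffn_weight_count(768, 4.0, gated=False)  # 2 * 768 * 3072
swiglu = ffn_weight_count(768, 8 / 3, gated=True)  # 3 * 768 * 2048
print(standard, swiglu)  # 4718592 4718592
```

With the gated layout but the standard ratio of 4.0, the same helper gives 7,077,888 weights, 1.5 times as many. This is why the default ratio drops to 8/3 for XGLU.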
Notes
This encoder uses nn.LazyLinear for node-level PE fusion. If you wrap it with DistributedDataParallel, first either run one representative no-grad forward pass (passing anchor_node_ids/anchor_node_type for the graph-transformer path) or load a checkpoint, so that all ranks see initialized weights before DDP is applied.
- TODO: Pairwise relative bias is currently materialized densely for the selected sequence. That is fine for moderate max_seq_len, but a chunked or sparse LPFormer-style path is still future work for larger sequences.
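The LazyLinear caveat can be illustrated with plain PyTorch. In this sketch a small nn.Sequential stands in for the encoder; it is not GraphTransformerEncoder. One no-grad forward pass turns every UninitializedParameter into a regular parameter, after which the model could be wrapped in DistributedDataParallel. For the real encoder, the warm-up call would need a representative HeteroData batch plus anchor_node_type/anchor_node_ids, as the note says. has_uninitialized_params is a hypothetical helper.

```python
import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedParameter


def has_uninitialized_params(module: nn.Module) -> bool:
    """True if any lazy parameter still has no shape; DDP cannot wrap such a module."""
    return any(isinstance(p, UninitializedParameter) for p in module.parameters())


# Hypothetical stand-in: a lazy projection whose input width is inferred on the first call.
model = nn.Sequential(nn.LazyLinear(32), nn.ReLU(), nn.Linear(32, 8))
assert has_uninitialized_params(model)

with torch.no_grad():
    model(torch.randn(4, 20))  # one representative forward materializes the LazyLinear

assert not has_uninitialized_params(model)
# Only now wrap it, e.g. torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]).
```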
Example
>>> from gigl.nn.graph_transformer import GraphTransformerEncoder
>>> from gigl.src.common.types.graph_data import NodeType
>>> encoder = GraphTransformerEncoder(
...     node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32},
...     edge_type_to_feat_dim_map={},
...     hid_dim=128,
...     out_dim=64,
...     num_layers=2,
...     num_heads=4,
... )
>>> # `data` is a sampled HeteroData batch and `device` is a torch.device.
>>> embeddings = encoder(data, anchor_node_type=NodeType("user"), device=device)
- forward(data, anchor_node_type=None, anchor_node_ids=None, device=None)
Run the forward pass of the Graph Transformer encoder.
- Parameters:
data (torch_geometric.data.hetero_data.HeteroData) – Input HeteroData object with node features (x_dict) and edge indices (edge_index_dict).
anchor_node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – Node type for which to compute embeddings. If None, uses the first node type in data.
anchor_node_ids (Optional[torch.Tensor]) – Optional tensor of local node indices within anchor_node_type to use as anchors. If None, uses the first batch_size nodes (seed nodes from neighbor sampling).
device (Optional[torch.device]) – Torch device for output tensors. If None, inferred from data.
- Returns:
Embeddings tensor of shape (num_anchor_nodes, out_dim).
- Return type:
torch.Tensor
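The two readout_mode options can be sketched on an already-encoded token tensor. The sketch assumes the anchor is token 0, scores neighbors with a scaled dot product against the anchor, and masks padding with valid_mask. The scoring function, the readout helper, and the l2_normalize flag (which mirrors should_l2_normalize_embedding_layer_output) are all assumptions. The encoder's learned attention weights and its output projection to out_dim are omitted.

```python
import math

import torch
import torch.nn.functional as F


def readout(
    tokens: torch.Tensor,
    valid_mask: torch.Tensor,
    mode: str = "anchor_neighbor_attention",
    l2_normalize: bool = False,
) -> torch.Tensor:
    """Hypothetical readout: tokens is (batch, seq, dim), valid_mask is (batch, seq) bool."""
    anchor = tokens[:, 0]  # assumption: token 0 is the encoded anchor
    if mode == "anchor_only":
        out = anchor
    elif mode == "anchor_neighbor_attention":
        neighbors = tokens[:, 1:]
        scores = torch.einsum("bd,bsd->bs", anchor, neighbors) / math.sqrt(tokens.size(-1))
        scores = scores.masked_fill(~valid_mask[:, 1:], float("-inf"))
        # Anchors whose neighbors are all padding give NaN rows; treat them as zero weight.
        weights = torch.nan_to_num(torch.softmax(scores, dim=-1), nan=0.0)
        out = anchor + torch.einsum("bs,bsd->bd", weights, neighbors)
    else:
        raise ValueError(f"Unknown readout mode: {mode}")
    return F.normalize(out, p=2.0, dim=-1) if l2_normalize else out
```

Padded tokens receive exactly zero weight, so an anchor with no valid neighbors falls back to its own token. In that case both modes agree.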
- class gigl.nn.graph_transformer.GraphTransformerEncoderLayer(model_dim, num_heads, feedforward_dim, dropout_rate=0.1, attention_dropout_rate=0.0, activation='gelu', relation_attention_mode='none', num_relations=0)
Bases: torch.nn.Module
Pre-norm transformer encoder layer with multi-head self-attention.
Uses F.scaled_dot_product_attention, which automatically selects the most efficient attention implementation (flash, memory-efficient, or math-based) based on input properties and hardware.
Adapted from RelGT’s EncoderLayer.
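F.scaled_dot_product_attention accepts a float attn_mask that is added to the scaled scores. This lets an additive attn_bias and a padding mask share one tensor. The sketch below assumes PyTorch 2.0 or later. Its shapes and the -inf padding convention are illustrative and not taken from GiGL's code. It checks SDPA against the explicit math path softmax(QK^T / sqrt(d) + bias) V.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, head_dim = 2, 4, 5, 8
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))

# Additive bias, e.g. from anchor-relative features; any shape broadcastable to
# (batch, heads, seq, seq) works.
attn_bias = torch.randn(batch, 1, seq, seq)
# Key padding: the last token of the second sequence is padding.
valid_mask = torch.ones(batch, seq, dtype=torch.bool)
valid_mask[1, -1] = False
# Fold padding into the float mask as -inf on padded keys.
full_bias = attn_bias.masked_fill(~valid_mask[:, None, None, :], float("-inf"))

out = F.scaled_dot_product_attention(q, k, v, attn_mask=full_bias, dropout_p=0.0)

# Reference "math" path: softmax(Q K^T / sqrt(d) + bias) V.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim) + full_bias
reference = torch.softmax(scores, dim=-1) @ v
```

Whichever backend SDPA selects, the result should match the reference up to floating-point tolerance. The padded key contributes nothing to any output.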
- Parameters:
model_dim (int) – Model dimension (d_model).
num_heads (int) – Number of attention heads. Must evenly divide model_dim.
feedforward_dim (int) – Inner dimension of the feed-forward network.
dropout_rate (float) – Dropout probability for feed-forward layers.
attention_dropout_rate (float) – Dropout probability for attention weights.
activation (str) – Activation function for the feed-forward network. Supported values: “gelu” (default), “relu”, “silu”, “tanh”, “geglu”, “swiglu”, “reglu”.
relation_attention_mode (Literal['none', 'edge_type_bilinear', 'edge_type_hgt']) – Optional relation-aware augmentation strategy for attention scores. "none" (default) uses the shared self-attention path. "edge_type_bilinear" adds a learned per-edge-type bilinear term for sampled directed graph edges; this changes attention weights, not value/message content. "edge_type_hgt" replaces the base query/key score on relation edges with an HGT-style relation transform and relation prior.
num_relations (int) – Number of relation channels expected in pairwise_relation_indices when relation-aware attention is enabled.
- Raises:
ValueError – If model_dim is not divisible by num_heads.
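A compact pre-norm layer shows how the pieces of this class fit together. LayerNorm runs before attention and before the feed-forward block, a residual addition follows each, SDPA computes the attention, and valid_mask zeroes padded states after each residual block. PreNormEncoderLayerSketch is hypothetical: it uses a plain GELU FFN, has no relation-aware attention, and is not GiGL's GraphTransformerEncoderLayer. In this sketch, padded keys are excluded from attention only if attn_bias masks them.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class PreNormEncoderLayerSketch(nn.Module):
    """Hypothetical pre-norm layer: x + Attn(LN(x)), then x + FFN(LN(x))."""

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        feedforward_dim: int,
        dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.attention_dropout_rate = attention_dropout_rate
        self.attn_norm = nn.LayerNorm(model_dim)
        self.qkv_proj = nn.Linear(model_dim, 3 * model_dim)
        self.out_proj = nn.Linear(model_dim, model_dim)
        self.ffn_norm = nn.LayerNorm(model_dim)
        self.ffn = nn.Sequential(  # no LayerNorm inside; normalization happens before it
            nn.Linear(model_dim, feedforward_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(feedforward_dim, model_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_bias: Optional[torch.Tensor] = None,
        valid_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch, seq, dim = x.shape
        q, k, v = self.qkv_proj(self.attn_norm(x)).chunk(3, dim=-1)
        # (batch, seq, dim) -> (batch, num_heads, seq, head_dim)
        q, k, v = (t.reshape(batch, seq, self.num_heads, -1).transpose(1, 2) for t in (q, k, v))
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_bias,
            dropout_p=self.attention_dropout_rate if self.training else 0.0,
        )
        x = x + self.out_proj(attn.transpose(1, 2).reshape(batch, seq, dim))
        if valid_mask is not None:
            x = x * valid_mask.unsqueeze(-1)  # zero padded token states
        x = x + self.ffn(self.ffn_norm(x))
        if valid_mask is not None:
            x = x * valid_mask.unsqueeze(-1)
        return x
```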
- forward(x, attn_bias=None, valid_mask=None, pairwise_relation_indices=None)
Forward pass.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch, seq, model_dim).
attn_bias (Optional[torch.Tensor]) – Optional attention bias of shape (batch, num_heads, seq, seq) or broadcastable to it. Added to the attention scores as an additive mask.
valid_mask (Optional[torch.Tensor]) – Optional boolean tensor of shape (batch, seq) used to zero out padded token states after each residual block.
pairwise_relation_indices (Optional[torch.Tensor]) – Optional long tensor of shape (num_relation_edges, 4) with sparse (batch_idx, query_pos, key_pos, relation_idx) coordinates.
- Returns:
Output tensor of shape (batch, seq, model_dim).
- Return type:
torch.Tensor
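The sparse pairwise_relation_indices format can be turned into a dense additive bias for the "edge_type_bilinear" idea of a learned per-edge-type bilinear score. This minimal sketch assumes one weight matrix per relation and per head (relation_weights is a hypothetical name and layout). Each (batch_idx, query_pos, key_pos, relation_idx) row adds q^T W_r k to its slot. It is not GiGL's implementation, and it materializes a dense (batch, heads, seq, seq) tensor.

```python
import torch


def relation_bilinear_bias(
    q: torch.Tensor,  # (batch, heads, seq, head_dim)
    k: torch.Tensor,  # (batch, heads, seq, head_dim)
    pairwise_relation_indices: torch.Tensor,  # (num_relation_edges, 4), long
    relation_weights: torch.Tensor,  # hypothetical: (num_relations, heads, head_dim, head_dim)
) -> torch.Tensor:
    """Hypothetical sparse-to-dense bilinear relation bias of shape (batch, heads, seq, seq)."""
    batch, heads, seq, _ = q.shape
    b, qi, ki, r = pairwise_relation_indices.unbind(dim=1)
    q_e = q.permute(0, 2, 1, 3)[b, qi]  # (E, heads, head_dim): query token of each edge
    k_e = k.permute(0, 2, 1, 3)[b, ki]  # (E, heads, head_dim): key token of each edge
    w_e = relation_weights[r]  # (E, heads, head_dim, head_dim)
    scores = torch.einsum("ehd,ehdf,ehf->eh", q_e, w_e, k_e)  # q^T W_r k per head
    bias = q.new_zeros(batch, seq, seq, heads)
    # Scatter each edge's per-head score into its (query, key) slot; duplicates accumulate.
    bias.index_put_((b, qi, ki), scores, accumulate=True)
    return bias.permute(0, 3, 1, 2)
```

The result has the (batch, num_heads, seq, seq) layout documented for attn_bias. It could therefore be summed with other biases before attention, and pairs without a relation edge keep a zero term.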