gigl.nn.graph_transformer#
Graph Transformer encoder for heterogeneous graphs.
Adapted from RelGT’s LocalModule (snap-stanford/relgt).
Converts heterogeneous graph data into fixed-length sequences via
heterodata_to_graph_transformer_input, processes through a stack of pre-norm
transformer encoder layers, then produces per-node embeddings via
attention-weighted neighbor readout.
Conforms to the same forward interface as HGT and SimpleHGN in
gigl.src.common.models.pyg.heterogeneous, making it a drop-in
replacement as the encoder in LinkPredictionGNN.
Classes#
- FeedForwardNetwork – Two-layer feed-forward network with configurable activation.
- GraphTransformerEncoder – Graph Transformer encoder for heterogeneous graphs.
- GraphTransformerEncoderLayer – Pre-norm transformer encoder layer with multi-head self-attention.
Module Contents#
- class gigl.nn.graph_transformer.FeedForwardNetwork(model_dim, feedforward_dim, dropout_rate=0.1, activation='gelu')[source]#
Bases: torch.nn.Module
Two-layer feed-forward network with configurable activation.
Supports standard activations (GELU, ReLU, SiLU) and XGLU family (SwiGLU, GeGLU, ReGLU) which use a gating mechanism.
Note: This module does NOT include LayerNorm. Normalization should be applied externally (e.g., pre-norm in the transformer layer).
Adapted from RelGT’s FeedForwardNetwork.
- Parameters:
model_dim (int) – Model (input and output) dimension of the FFN.
feedforward_dim (int) – Inner dimension of the two-layer MLP.
dropout_rate (float) – Dropout probability applied after each linear layer.
activation (str) – Activation function name. Supported standard values: “gelu” (default), “relu”, “silu”, “tanh”; XGLU family: “geglu”, “swiglu”, “reglu”. XGLU activations use gating, activation(xW) * xV, which requires projecting to 2x feedforward_dim internally (see the gating sketch below).
Initialize internal Module state, shared by both nn.Module and ScriptModule.
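The XGLU gating can be illustrated outside the module. The following is a minimal sketch of the gating idea (SwiGLU shown) with a fused 2x projection; the dimensions and layers here are illustrative stand-ins, not the module's internals:
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> model_dim, feedforward_dim = 128, 512
>>> proj_in = nn.Linear(model_dim, 2 * feedforward_dim)   # fused W and V branches
>>> proj_out = nn.Linear(feedforward_dim, model_dim)
>>> x = torch.randn(4, 16, model_dim)                      # (batch, seq, model_dim)
>>> gate, value = proj_in(x).chunk(2, dim=-1)              # split the fused projection
>>> y = proj_out(F.silu(gate) * value)                     # SwiGLU: silu(xW) * xV
>>> y.shape                                                # model_dim in, model_dim out
torch.Size([4, 16, 128])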
- class gigl.nn.graph_transformer.GraphTransformerEncoder(node_type_to_feat_dim_map, edge_type_to_feat_dim_map, hid_dim, out_dim=128, num_layers=2, num_heads=2, max_seq_len=128, hop_distance=2, sequence_construction_method='khop', sequence_positional_encoding_type=None, dropout_rate=0.1, attention_dropout_rate=0.0, should_l2_normalize_embedding_layer_output=False, pe_attr_names=None, anchor_based_attention_bias_attr_names=None, anchor_based_input_attr_names=None, anchor_based_input_embedding_dict=None, pairwise_attention_bias_attr_names=None, feature_embedding_layer_dict=None, pe_integration_mode='concat', activation='gelu', feedforward_ratio=None, **kwargs)[source]#
Bases: torch.nn.Module
Graph Transformer encoder for heterogeneous graphs.
Converts heterogeneous graph data into fixed-length sequences via heterodata_to_graph_transformer_input, processes through pre-norm transformer encoder layers, and produces per-node embeddings via attention-weighted neighbor readout (from RelGT’s LocalModule).
Conforms to the same forward interface as HGT and SimpleHGN, making it a drop-in encoder for LinkPredictionGNN.
- Parameters:
node_type_to_feat_dim_map (dict[gigl.src.common.types.graph_data.NodeType, int]) – Dictionary mapping node types to their input feature dimensions.
edge_type_to_feat_dim_map (dict[gigl.src.common.types.graph_data.EdgeType, int]) – Dictionary mapping edge types to their feature dimensions. Accepted for interface conformance with HGT / SimpleHGN; edge features are not used by the graph transformer.
hid_dim (int) – Hidden dimension for transformer layers. All node types are projected to this dimension before processing.
out_dim (int) – Output embedding dimension.
num_layers (int) – Number of transformer encoder layers.
num_heads (int) – Number of attention heads per layer. Must evenly divide hid_dim.
max_seq_len (int) – Maximum sequence length for the graph-to-sequence transform. Neighborhoods are truncated to this length.
hop_distance (int) – Number of hops for neighborhood extraction in the graph-to-sequence transform when using "khop" sequence construction.
sequence_construction_method (Literal['khop', 'ppr']) – Sequence builder used to create tokens for each anchor. "khop" expands the sampled graph by hop distance, while "ppr" consumes outgoing "ppr" edges sorted by weight.
sequence_positional_encoding_type (Optional[str]) – Optional sequence-level positional encoding applied after sequence construction. Supported values are None and "sinusoidal". Lower-cost future extensions could add learned absolute position embeddings here, while attention-level options like RoPE or ALiBi would require changes inside the attention block.
dropout_rate (float) – Dropout probability for feed-forward layers.
attention_dropout_rate (float) – Dropout probability for attention weights.
should_l2_normalize_embedding_layer_output (bool) – Whether to L2 normalize output embeddings.
pe_attr_names (Optional[list[str]]) – List of node-level positional encoding attribute names. In "concat" mode these are concatenated to sequence features. In "add" mode they are projected to hid_dim and added to node features before sequence construction.
anchor_based_attention_bias_attr_names (Optional[list[str]]) – List of anchor-relative feature names used as additive attention bias for sequence keys. Sparse graph-level attributes are looked up from data, and the reserved name "ppr_weight" resolves to PPR edge weights in PPR mode. Example: ['hop_distance', 'ppr_weight'], where hop_distance is a sparse matrix attribute on data and ppr_weight is extracted from PPR edge weights.
anchor_based_input_attr_names (Optional[list[str]]) – List of anchor-relative attribute names used as token-aligned input features. Sparse graph-level attributes are looked up from data, and "ppr_weight" resolves to PPR edge weights in PPR mode. These are projected to hid_dim and added to the sequence tokens after sequence construction. Example: ['hop_distance', 'ppr_weight'] for continuous features, or ['hop_distance'] when hop_distance will be embedded via anchor_based_input_embedding_dict.
anchor_based_input_embedding_dict (Optional[torch.nn.ModuleDict]) – Optional ModuleDict mapping a subset of anchor_based_input_attr_names to per-attribute embedding layers. These attributes are treated as discrete indices and their embedded contributions are added to the sequence tokens. Padding is masked out using the sequence valid mask. Example: nn.ModuleDict({'hop_distance': nn.Embedding(10, hid_dim)}) to embed hop distances 0-9 into hid_dim-dimensional vectors. The embedding output dimension must match hid_dim. See the configuration sketch after this parameter list.
pairwise_attention_bias_attr_names (Optional[list[str]]) – List of pairwise feature names used as additive attention bias. These must correspond to sparse graph-level attributes on data.
feature_embedding_layer_dict (Optional[torch.nn.ModuleDict]) – Optional ModuleDict mapping node types to feature embedding layers. If provided, these are applied to node features before node projection. (default: None)
pe_integration_mode (Literal['concat', 'add']) – How to fuse positional encodings into the model input. "concat" preserves the current behavior by concatenating node-level PE to token features. "add" uses node-level additive PE before sequence construction and attention bias for relative encodings.
activation (str) – Activation function for the feed-forward network in each transformer layer. Supported standard values: “gelu” (default), “relu”, “silu”, “tanh”; XGLU family: “geglu”, “swiglu”, “reglu”. XGLU activations use gating: activation(xW) * xV.
feedforward_ratio (Optional[float]) – Ratio of feedforward dimension to hidden dimension (feedforward_dim = hid_dim * feedforward_ratio). If None (default), uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, following the convention that XGLU’s gating doubles the effective parameters, so a smaller ratio maintains similar parameter count.
kwargs (object)
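For orientation, a hedged configuration sketch combining PPR-based sequence construction with the anchor-relative attributes described above. It assumes the sampled data carries the referenced sparse attributes and "ppr" edges; the node types and dimensions are illustrative, not a recommended recipe:
>>> import torch.nn as nn
>>> from gigl.nn.graph_transformer import GraphTransformerEncoder
>>> from gigl.src.common.types.graph_data import NodeType
>>> hid_dim = 128
>>> encoder = GraphTransformerEncoder(
...     node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32},
...     edge_type_to_feat_dim_map={},  # accepted for interface conformance; unused
...     hid_dim=hid_dim,
...     out_dim=64,
...     sequence_construction_method="ppr",
...     anchor_based_attention_bias_attr_names=["ppr_weight"],  # bias keys by PPR weight
...     anchor_based_input_attr_names=["hop_distance"],         # token-aligned input feature
...     anchor_based_input_embedding_dict=nn.ModuleDict(
...         {"hop_distance": nn.Embedding(10, hid_dim)}         # discrete hops 0-9 -> hid_dim
...     ),
... )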
Notes
This encoder uses nn.LazyLinear for node-level PE fusion. If you wrap it with DistributedDataParallel, run one representative no-grad forward first, passing anchor_node_ids / anchor_node_type for the graph-transformer path, or load a checkpoint before DDP so all ranks see initialized weights; a sketch of this warm-up is shown below.
- TODO: Pairwise relative bias is currently materialized densely for the selected sequence. That is fine for moderate max_seq_len, but a chunked or sparse LPFormer-style path is still future work for larger sequences.
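A minimal sketch of the DDP warm-up, assuming the process group is already initialized and that encoder, sample_batch, and anchor_ids are stand-ins for your module and a representative batch:
>>> import torch
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from gigl.src.common.types.graph_data import NodeType
>>> with torch.no_grad():  # materialize LazyLinear weights on every rank
...     _ = encoder(sample_batch, anchor_node_type=NodeType("user"), anchor_node_ids=anchor_ids)
>>> ddp_encoder = DDP(encoder)  # wrap only after parameters are initialized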
Example
>>> from gigl.nn.graph_transformer import GraphTransformerEncoder
>>> from gigl.src.common.types.graph_data import NodeType
>>> encoder = GraphTransformerEncoder(
...     node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32},
...     edge_type_to_feat_dim_map={},
...     hid_dim=128,
...     out_dim=64,
...     num_layers=2,
...     num_heads=4,
... )
>>> embeddings = encoder(data, anchor_node_type=NodeType("user"), device=device)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data, anchor_node_type=None, anchor_node_ids=None, device=None)[source]#
Run the forward pass of the Graph Transformer encoder.
- Parameters:
data (torch_geometric.data.hetero_data.HeteroData) – Input HeteroData object with node features (x_dict) and edge indices (edge_index_dict).
anchor_node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – Node type for which to compute embeddings. If None, uses the first node type in data.
anchor_node_ids (Optional[torch.Tensor]) – Optional tensor of local node indices within anchor_node_type to use as anchors. If None, uses the first batch_size nodes (seed nodes from neighbor sampling).
device (Optional[torch.device]) – Torch device for output tensors. If None, inferred from data.
- Returns:
Embeddings tensor of shape (num_anchor_nodes, out_dim).
- Return type:
torch.Tensor
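Continuing the class-level example above (out_dim=64, data from neighbor sampling), a hedged sketch of selecting explicit anchors and checking the documented output shape:
>>> import torch
>>> anchor_ids = torch.tensor([0, 5, 9])  # local indices within the "user" node type
>>> embeddings = encoder(
...     data,
...     anchor_node_type=NodeType("user"),
...     anchor_node_ids=anchor_ids,
... )
>>> embeddings.shape  # (num_anchor_nodes, out_dim)
torch.Size([3, 64])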
- class gigl.nn.graph_transformer.GraphTransformerEncoderLayer(model_dim, num_heads, feedforward_dim, dropout_rate=0.1, attention_dropout_rate=0.0, activation='gelu')[source]#
Bases: torch.nn.Module
Pre-norm transformer encoder layer with multi-head self-attention.
Uses F.scaled_dot_product_attention, which automatically selects the most efficient attention implementation (flash, memory-efficient, or math-based) based on input properties and hardware.
Adapted from RelGT’s EncoderLayer.
- Parameters:
model_dim (int) – Model dimension (d_model).
num_heads (int) – Number of attention heads. Must evenly divide model_dim.
feedforward_dim (int) – Inner dimension of the feed-forward network.
dropout_rate (float) – Dropout probability for feed-forward layers.
attention_dropout_rate (float) – Dropout probability for attention weights.
activation (str) – Activation function for the feed-forward network. Supported values: “gelu” (default), “relu”, “silu”, “tanh”, “geglu”, “swiglu”, “reglu”.
- Raises:
ValueError – If model_dim is not divisible by num_heads.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, attn_bias=None, valid_mask=None)[source]#
Forward pass.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch, seq, model_dim).
attn_bias (Optional[torch.Tensor]) – Optional attention bias of shape (batch, num_heads, seq, seq) or broadcastable. Added as an additive mask to attention scores.
valid_mask (Optional[torch.Tensor]) – Optional boolean tensor of shape (batch, seq) used to zero out padded token states after each residual block.
- Returns:
Output tensor of shape (batch, seq, model_dim).
- Return type:
torch.Tensor
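A minimal sketch of a single layer call with an additive bias that blocks attention to padded keys; the shapes follow the documented signature, and the constructor values here are illustrative:
>>> import torch
>>> from gigl.nn.graph_transformer import GraphTransformerEncoderLayer
>>> batch, seq, model_dim, num_heads = 2, 16, 128, 4
>>> layer = GraphTransformerEncoderLayer(
...     model_dim=model_dim, num_heads=num_heads, feedforward_dim=4 * model_dim
... )
>>> x = torch.randn(batch, seq, model_dim)
>>> valid_mask = torch.ones(batch, seq, dtype=torch.bool)
>>> valid_mask[:, 12:] = False                    # mark the last tokens as padding
>>> attn_bias = torch.zeros(batch, num_heads, seq, seq)
>>> attn_bias[:, :, :, 12:] = float("-inf")       # keep attention away from padded keys
>>> out = layer(x, attn_bias=attn_bias, valid_mask=valid_mask)
>>> out.shape
torch.Size([2, 16, 128])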