gigl.nn#

GiGL NN Module

Classes#

FeedForwardNetwork

Two-layer feed-forward network with configurable activation.

GraphTransformerEncoder

Graph Transformer encoder for heterogeneous graphs.

GraphTransformerEncoderLayer

Pre-norm transformer encoder layer with multi-head self-attention.

LightGCN

LightGCN model with TorchRec integration for distributed ID embeddings.

LinkPredictionGNN

Link Prediction GNN model for both homogeneous and heterogeneous use cases.

RetrievalLoss

A loss layer built on top of the tensorflow_recommenders implementation.

Package Contents#

class gigl.nn.FeedForwardNetwork(model_dim, feedforward_dim, dropout_rate=0.1, activation='gelu')[source]#

Bases: torch.nn.Module

Two-layer feed-forward network with configurable activation.

Supports standard activations (GELU, ReLU, SiLU) and XGLU family (SwiGLU, GeGLU, ReGLU) which use a gating mechanism.

Note: This module does NOT include LayerNorm. Normalization should be applied externally (e.g., pre-norm in the transformer layer).

Adapted from RelGT’s FeedForwardNetwork.

Parameters:
  • model_dim (int) – Model (input and output) dimension of the FFN.

  • feedforward_dim (int) – Inner dimension of the two-layer MLP.

  • dropout_rate (float) – Dropout probability applied after each linear layer.

  • activation (str) – Activation function name. Supported values: standard activations “gelu” (default), “relu”, “silu”, “tanh”; XGLU family “geglu”, “swiglu”, “reglu”. XGLU activations use gating, activation(xW) * xV, which requires projecting to 2x feedforward_dim internally.

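A minimal sketch of the XGLU gating described above (the gate/value split order here is an assumption, not taken from the gigl source):

>>> import torch
>>> import torch.nn.functional as F
>>> model_dim, feedforward_dim = 8, 16
>>> x = torch.randn(2, 4, model_dim)  # (batch, seq, model_dim)
>>> w_in = torch.nn.Linear(model_dim, 2 * feedforward_dim)  # 2x projection for gating
>>> w_out = torch.nn.Linear(feedforward_dim, model_dim)
>>> gate, value = w_in(x).chunk(2, dim=-1)  # split the doubled projection
>>> hidden = F.silu(gate) * value  # SwiGLU: activation(xW) * xV
>>> w_out(hidden).shape  # back to model_dim
torch.Size([2, 4, 8])
>>> from gigl.nn import FeedForwardNetwork
>>> ffn = FeedForwardNetwork(model_dim=8, feedforward_dim=16, activation="swiglu")
>>> ffn(x).shape
torch.Size([2, 4, 8])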

forward(x)[source]#

Forward pass.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch, seq, model_dim).

Returns:

Output tensor of shape (batch, seq, model_dim).

Return type:

torch.Tensor

reset_parameters()[source]#

Reinitialize all learnable parameters.

Return type:

None

class gigl.nn.GraphTransformerEncoder(node_type_to_feat_dim_map, edge_type_to_feat_dim_map, hid_dim, out_dim=128, num_layers=2, num_heads=2, max_seq_len=128, hop_distance=2, sequence_construction_method='khop', sequence_positional_encoding_type=None, dropout_rate=0.1, attention_dropout_rate=0.0, should_l2_normalize_embedding_layer_output=False, pe_attr_names=None, anchor_based_attention_bias_attr_names=None, anchor_based_input_attr_names=None, anchor_based_input_embedding_dict=None, pairwise_attention_bias_attr_names=None, feature_embedding_layer_dict=None, pe_integration_mode='concat', activation='gelu', feedforward_ratio=None, **kwargs)[source]#

Bases: torch.nn.Module

Graph Transformer encoder for heterogeneous graphs.

Converts heterogeneous graph data into fixed-length sequences via heterodata_to_graph_transformer_input, processes through pre-norm transformer encoder layers, and produces per-node embeddings via attention-weighted neighbor readout (from RelGT’s LocalModule).

Conforms to the same forward interface as HGT and SimpleHGN, making it a drop-in encoder for LinkPredictionGNN.

Parameters:
  • node_type_to_feat_dim_map (dict[gigl.src.common.types.graph_data.NodeType, int]) – Dictionary mapping node types to their input feature dimensions.

  • edge_type_to_feat_dim_map (dict[gigl.src.common.types.graph_data.EdgeType, int]) – Dictionary mapping edge types to their feature dimensions. Accepted for interface conformance with HGT/SimpleHGN; edge features are not used by the graph transformer.

  • hid_dim (int) – Hidden dimension for transformer layers. All node types are projected to this dimension before processing.

  • out_dim (int) – Output embedding dimension.

  • num_layers (int) – Number of transformer encoder layers.

  • num_heads (int) – Number of attention heads per layer. Must evenly divide hid_dim.

  • max_seq_len (int) – Maximum sequence length for the graph-to-sequence transform. Neighborhoods are truncated to this length.

  • hop_distance (int) – Number of hops for neighborhood extraction in the graph-to-sequence transform when using "khop" sequence construction.

  • sequence_construction_method (Literal['khop', 'ppr']) – Sequence builder used to create tokens for each anchor. "khop" expands the sampled graph by hop distance, while "ppr" consumes outgoing "ppr" edges sorted by weight.

  • sequence_positional_encoding_type (Optional[str]) – Optional sequence-level positional encoding applied after sequence construction. Supported values are None and "sinusoidal". Lower-cost future extensions could add learned absolute position embeddings here, while attention-level options like RoPE or ALiBi would require changes inside the attention block.

  • dropout_rate (float) – Dropout probability for feed-forward layers.

  • attention_dropout_rate (float) – Dropout probability for attention weights.

  • should_l2_normalize_embedding_layer_output (bool) – Whether to L2 normalize output embeddings.

  • pe_attr_names (Optional[list[str]]) – List of node-level positional encoding attribute names. In "concat" mode these are concatenated to sequence features. In "add" mode they are projected to hid_dim and added to node features before sequence construction.

  • anchor_based_attention_bias_attr_names (Optional[list[str]]) – List of anchor-relative feature names used as additive attention bias for sequence keys. Sparse graph-level attributes are looked up from data and the reserved name "ppr_weight" resolves to PPR edge weights in PPR mode. Example: ['hop_distance', 'ppr_weight'] where hop_distance is a sparse matrix attribute on data and ppr_weight is extracted from PPR edge weights.

  • anchor_based_input_attr_names (Optional[list[str]]) – List of anchor-relative attribute names used as token-aligned input features. Sparse graph-level attributes are looked up from data and "ppr_weight" resolves to PPR edge weights in PPR mode. These are projected to hid_dim and added to the sequence tokens after sequence construction. Example: ['hop_distance', 'ppr_weight'] for continuous features, or ['hop_distance'] when hop_distance will be embedded via anchor_based_input_embedding_dict.

  • anchor_based_input_embedding_dict (Optional[torch.nn.ModuleDict]) – Optional ModuleDict mapping a subset of anchor_based_input_attr_names to per-attribute embedding layers. These attributes are treated as discrete indices and their embedded contributions are added to the sequence tokens. Padding is masked out using the sequence valid mask. Example: nn.ModuleDict({'hop_distance': nn.Embedding(10, hid_dim)}) to embed hop distances 0-9 into hid_dim-dimensional vectors. The embedding output dimension must match hid_dim.

  • pairwise_attention_bias_attr_names (Optional[list[str]]) – List of pairwise feature names used as additive attention bias. These must correspond to sparse graph-level attributes on data.

  • feature_embedding_layer_dict (Optional[torch.nn.ModuleDict]) – Optional ModuleDict mapping node types to feature embedding layers. If provided, these are applied to node features before node projection. (default: None)

  • pe_integration_mode (Literal['concat', 'add']) – How to fuse positional encodings into the model input. "concat" preserves the current behavior by concatenating node-level PE to token features. "add" uses node-level additive PE before sequence construction and attention bias for relative encodings.

  • activation (str) – Activation function for the feed-forward network in each transformer layer. Supported values: standard activations “gelu” (default), “relu”, “silu”, “tanh”; XGLU family “geglu”, “swiglu”, “reglu”. XGLU activations use gating: activation(xW) * xV.

  • feedforward_ratio (Optional[float]) – Ratio of the feed-forward dimension to the hidden dimension (feedforward_dim = hid_dim * feedforward_ratio). If None (default), uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants: XGLU’s gating doubles the first projection’s parameters, so the smaller ratio keeps the total parameter count comparable (see the worked arithmetic after this list).

  • kwargs (object)
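To make the feedforward_ratio defaults concrete, a short worked example (the integer rounding is an assumption about the implementation):

>>> hid_dim = 128
>>> standard_ffn_dim = int(hid_dim * 4.0)  # 512 for standard activations
>>> 2 * hid_dim * standard_ffn_dim  # two model_dim x ffn_dim projections
131072
>>> xglu_ffn_dim = int(hid_dim * 8 / 3)  # 341 for XGLU variants
>>> hid_dim * 2 * xglu_ffn_dim + xglu_ffn_dim * hid_dim  # doubled first projection
130944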

Notes

This encoder uses nn.LazyLinear for node-level PE fusion. If you wrap it with DistributedDataParallel, run one representative no-grad forward first, passing anchor_node_ids/anchor_node_type for the graph-transformer path, or load a checkpoint before DDP so all ranks see initialized weights.
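A hedged sketch of that warm-up for an encoder constructed as in the Example below; data, anchor_ids, device, and local_rank are assumed to come from the surrounding distributed setup:

>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> with torch.no_grad():  # one forward materializes every nn.LazyLinear parameter
...     _ = encoder(
...         data,
...         anchor_node_type=NodeType("user"),
...         anchor_node_ids=anchor_ids,
...         device=device,
...     )
>>> ddp_encoder = DDP(encoder.to(device), device_ids=[local_rank])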

TODO: Pairwise relative bias is currently materialized densely for the selected sequence. That is fine for moderate max_seq_len, but a chunked or sparse LPFormer-style path is still future work for larger sequences.

Example

>>> import torch
>>> from gigl.nn.graph_transformer import GraphTransformerEncoder
>>> from gigl.src.common.types.graph_data import NodeType
>>> encoder = GraphTransformerEncoder(
...     node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32},
...     edge_type_to_feat_dim_map={},
...     hid_dim=128,
...     out_dim=64,
...     num_layers=2,
...     num_heads=4,
... )
>>> # `data` is a HeteroData batch with x_dict and edge_index_dict.
>>> device = torch.device("cpu")
>>> embeddings = encoder(data, anchor_node_type=NodeType("user"), device=device)


forward(data, anchor_node_type=None, anchor_node_ids=None, device=None)[source]#

Run the forward pass of the Graph Transformer encoder.

Parameters:
  • data (torch_geometric.data.hetero_data.HeteroData) – Input HeteroData object with node features (x_dict) and edge indices (edge_index_dict).

  • anchor_node_type (Optional[gigl.src.common.types.graph_data.NodeType]) – Node type for which to compute embeddings. If None, uses the first node type in data.

  • anchor_node_ids (Optional[torch.Tensor]) – Optional tensor of local node indices within anchor_node_type to use as anchors. If None, uses the first batch_size nodes (seed nodes from neighbor sampling).

  • device (Optional[torch.device]) – Torch device for output tensors. If None, inferred from data.

Returns:

Embeddings tensor of shape (num_anchor_nodes, out_dim).

Return type:

torch.Tensor

class gigl.nn.GraphTransformerEncoderLayer(model_dim, num_heads, feedforward_dim, dropout_rate=0.1, attention_dropout_rate=0.0, activation='gelu')[source]#

Bases: torch.nn.Module

Pre-norm transformer encoder layer with multi-head self-attention.

Uses F.scaled_dot_product_attention which automatically selects the most efficient attention implementation (flash, memory-efficient, or math-based) based on input properties and hardware.

Adapted from RelGT’s EncoderLayer.

Parameters:
  • model_dim (int) – Model dimension (d_model).

  • num_heads (int) – Number of attention heads. Must evenly divide model_dim.

  • feedforward_dim (int) – Inner dimension of the feed-forward network.

  • dropout_rate (float) – Dropout probability for feed-forward layers.

  • attention_dropout_rate (float) – Dropout probability for attention weights.

  • activation (str) – Activation function for the feed-forward network. Supported values: “gelu” (default), “relu”, “silu”, “tanh”, “geglu”, “swiglu”, “reglu”.

Raises:

ValueError – If model_dim is not divisible by num_heads.


forward(x, attn_bias=None, valid_mask=None)[source]#

Forward pass.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch, seq, model_dim).

  • attn_bias (Optional[torch.Tensor]) – Optional attention bias of shape (batch, num_heads, seq, seq) or broadcastable. Added as an additive mask to attention scores.

  • valid_mask (Optional[torch.Tensor]) – Optional boolean tensor of shape (batch, seq) used to zero out padded token states after each residual block.

Returns:

Output tensor of shape (batch, seq, model_dim).

Return type:

torch.Tensor
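A brief shape-check sketch for this layer (tensor contents are arbitrary):

>>> import torch
>>> from gigl.nn import GraphTransformerEncoderLayer
>>> layer = GraphTransformerEncoderLayer(model_dim=128, num_heads=4, feedforward_dim=512)
>>> x = torch.randn(2, 16, 128)  # (batch, seq, model_dim)
>>> attn_bias = torch.zeros(2, 4, 16, 16)  # (batch, num_heads, seq, seq)
>>> valid_mask = torch.ones(2, 16, dtype=torch.bool)  # all tokens valid
>>> layer(x, attn_bias=attn_bias, valid_mask=valid_mask).shape
torch.Size([2, 16, 128])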

reset_parameters()[source]#

Reinitialize all learnable parameters.

Return type:

None

class gigl.nn.LightGCN(node_type_to_num_nodes, embedding_dim=64, num_layers=2, device=torch.device('cpu'), layer_weights=None)[source]#

Bases: torch.nn.Module

LightGCN model with TorchRec integration for distributed ID embeddings.

Reference: https://arxiv.org/pdf/2002.02126

This class extends the basic LightGCN implementation to use TorchRec’s distributed embedding tables for handling large-scale ID embeddings.

Parameters:
  • node_type_to_num_nodes (Union[int, Dict[NodeType, int]]) – Map from node types to node counts. Can also pass a single int for homogeneous graphs.

  • embedding_dim (int) – Dimension of node embeddings D. Default: 64.

  • num_layers (int) – Number of LightGCN propagation layers K. Default: 2.

  • device (torch.device) – Device to run the computation on. Default: CPU.

  • layer_weights (Optional[List[float]]) – Weights for [e^(0), e^(1), …, e^(K)]. Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None.


forward(data, device, output_node_types=None, anchor_node_ids=None)[source]#

Forward pass of the LightGCN model.

Parameters:
  • data (Union[Data, HeteroData]) – Graph data (homogeneous or heterogeneous).

  • device (torch.device) – Device to run the computation on.

  • output_node_types (Optional[List[NodeType]]) – List of node types to return embeddings for. Required for heterogeneous graphs. Default: None.

  • anchor_node_ids (Optional[torch.Tensor]) – Local node indices to return embeddings for. If None, returns embeddings for all nodes. Default: None.

Returns:

Node embeddings.

For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim]. For heterogeneous graphs, returns dict mapping node types to embeddings.

Return type:

Union[torch.Tensor, Dict[NodeType, torch.Tensor]]
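A hedged usage sketch for the heterogeneous case; data is assumed to be a HeteroData batch produced by the surrounding pipeline:

>>> import torch
>>> from gigl.nn import LightGCN
>>> from gigl.src.common.types.graph_data import NodeType
>>> model = LightGCN(
...     node_type_to_num_nodes={NodeType("user"): 1000, NodeType("item"): 500},
...     embedding_dim=64,
...     num_layers=2,  # K = 2 propagation layers
...     layer_weights=[1 / 3, 1 / 3, 1 / 3],  # one weight per e^(0)..e^(K)
... )
>>> emb = model(data, device=torch.device("cpu"), output_node_types=[NodeType("user")])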

class gigl.nn.LinkPredictionGNN(encoder, decoder)[source]#

Bases: torch.nn.Module

Link Prediction GNN model for both homogeneous and heterogeneous use cases.

Parameters:
  • encoder (torch.nn.Module) – Either a BasicGNN or a heterogeneous GNN for generating embeddings.

  • decoder (torch.nn.Module) – Decoder for transforming embeddings into scores. Recommended to use gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder.


decode(query_embeddings, candidate_embeddings)[source]#

Computes link scores between query and candidate embeddings using the decoder.

Parameters:
  • query_embeddings (torch.Tensor)

  • candidate_embeddings (torch.Tensor)

Return type:

torch.Tensor

forward(data, device, output_node_types=None)[source]#

Computes node embeddings by running the encoder on the given graph data.

Parameters:
  • data (Union[torch_geometric.data.Data, torch_geometric.data.HeteroData])

  • device (torch.device)

  • output_node_types (Optional[list[gigl.src.common.types.graph_data.NodeType]])

Return type:

Union[torch.Tensor, dict[gigl.src.common.types.graph_data.NodeType, torch.Tensor]]
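Putting forward and decode together, a hedged sketch for a homogeneous graph (data, src_ids, and dst_ids are assumed from the surrounding batch):

>>> embeddings = model(data, device=device)  # (num_nodes, out_dim)
>>> scores = model.decode(
...     query_embeddings=embeddings[src_ids],
...     candidate_embeddings=embeddings[dst_ids],
... )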

to_ddp(device, find_unused_encoder_parameters=False)[source]#

Converts the model to DistributedDataParallel (DDP) mode.

We do this because DDP does not expect the forward method of the modules it wraps to be called directly; see how DistributedDataParallel.forward calls _pre_forward (pytorch/pytorch). Without this wrapping, calling forward() on the individual modules may not work correctly.

Calling this function makes it safe to call the wrapped submodules directly, e.g. LinkPredictionGNN.decoder(data, device).

Parameters:
  • device (Optional[torch.device]) – The device to which the model should be moved. If None, will default to CPU.

  • find_unused_encoder_parameters (bool) – Whether to find unused parameters in the model. This should be set to True if the model has parameters that are not used in the forward pass.

Returns:

A new instance of LinkPredictionGNN for use with DDP.

Return type:

LinkPredictionGNN

unwrap_from_ddp()[source]#

Unwraps the model from DistributedDataParallel if it is wrapped.

Returns:

A new instance of LinkPredictionGNN with the original encoder and decoder.

Return type:

LinkPredictionGNN
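A hedged round-trip sketch, assuming torch.distributed is already initialized:

>>> model = LinkPredictionGNN(encoder=encoder, decoder=decoder)
>>> ddp_model = model.to_ddp(device=device, find_unused_encoder_parameters=False)
>>> # ... training loop using ddp_model ...
>>> plain_model = ddp_model.unwrap_from_ddp()  # e.g. before saving a checkpoint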

property decoder: torch.nn.Module#
Return type:

torch.nn.Module

property encoder: torch.nn.Module#
Return type:

torch.nn.Module

class gigl.nn.RetrievalLoss(loss=None, temperature=None, remove_accidental_hits=False)[source]#

Bases: torch.nn.Module

A loss layer built on top of the tensorflow_recommenders implementation. https://www.tensorflow.org/recommenders/api_docs/python/tfrs/tasks/Retrieval

By default, the loss is computed as cross_entropy(torch.mm(query_embeddings, candidate_embeddings.T), positive_indices, reduction='sum'), where the candidate embeddings are torch.cat((positive_embeddings, random_negative_embeddings)). This encourages the model to generate query embeddings whose similarity scores with their own first hops are higher than with other queries’ first hops and with random negatives. We also filter out cases where a query would accidentally treat its own positives as negatives.
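The default computation can be sketched directly (names here are illustrative, not the gigl implementation):

>>> import torch
>>> import torch.nn.functional as F
>>> num_pos, num_neg, dim = 4, 8, 16
>>> query_embeddings = torch.randn(num_pos, dim)  # one row per positive pair
>>> positive_embeddings = torch.randn(num_pos, dim)
>>> random_negative_embeddings = torch.randn(num_neg, dim)
>>> candidate_embeddings = torch.cat((positive_embeddings, random_negative_embeddings))
>>> scores = torch.mm(query_embeddings, candidate_embeddings.T)  # [num_pos, num_pos + num_neg]
>>> positive_indices = torch.arange(num_pos)  # row i's positive sits in column i
>>> loss = F.cross_entropy(scores, positive_indices, reduction="sum")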

Parameters:
  • loss (Optional[nn.Module]) – Custom loss function to be used. If None, the default is nn.CrossEntropyLoss(reduction="sum").

  • temperature (Optional[float]) – Temperature scaling applied to scores before computing cross-entropy loss. If not None, scores are divided by the temperature value.

  • remove_accidental_hits (bool) – Whether to remove accidental hits where the query’s positive items are also present in the negative samples.


forward(repeated_candidate_scores, candidate_ids, repeated_query_ids, device, candidate_sampling_probability=None)[source]#
Parameters:
  • repeated_candidate_scores (torch.Tensor) – The prediction scores between each repeated query user and each candidate. Here, repeated means each query user is repeated once per positive label they have. Tensor shape: [num_positives, num_positives + num_hard_negatives + num_random_negatives]

  • candidate_ids (torch.Tensor) – Concatenated IDs of the candidates. Tensor shape: [num_positives + num_hard_negatives + num_random_negatives]

  • repeated_query_ids (torch.Tensor) – Repeated query user IDs. Tensor shape: [num_positives]

  • candidate_sampling_probability (Optional[torch.Tensor]) – Optional tensor of candidate sampling probabilities. When given, it is used to correct the logits to reflect the sampling probability of negative candidates. Tensor shape: [num_positives + num_hard_negatives + num_random_negatives]

  • device (torch.device)
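A hedged usage sketch matching the shapes documented above (values are arbitrary):

>>> import torch
>>> from gigl.nn import RetrievalLoss
>>> loss_fn = RetrievalLoss(temperature=0.1, remove_accidental_hits=True)
>>> num_pos, num_cand = 4, 12  # num_cand = positives + hard + random negatives
>>> repeated_candidate_scores = torch.randn(num_pos, num_cand)
>>> candidate_ids = torch.arange(num_cand)
>>> repeated_query_ids = torch.tensor([0, 0, 1, 2])  # queries repeated per positive
>>> loss = loss_fn(
...     repeated_candidate_scores,
...     candidate_ids,
...     repeated_query_ids,
...     device=torch.device("cpu"),
... )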