gigl.nn#
GiGL NN Module
Classes#
| Class | Summary |
|---|---|
| LightGCN | LightGCN model with TorchRec integration for distributed ID embeddings. |
| LinkPredictionGNN | Link Prediction GNN model for both homogeneous and heterogeneous use cases. |
| RetrievalLoss | A loss layer built on top of the tensorflow_recommenders implementation. |
Package Contents#
- class gigl.nn.LightGCN(node_type_to_num_nodes, embedding_dim=64, num_layers=2, device=torch.device('cpu'), layer_weights=None)[source]#
Bases:
torch.nn.Module

LightGCN model with TorchRec integration for distributed ID embeddings.
Reference: https://arxiv.org/pdf/2002.02126
This class extends the basic LightGCN implementation to use TorchRec’s distributed embedding tables for handling large-scale ID embeddings.
- Parameters:
node_type_to_num_nodes (Union[int, Dict[NodeType, int]]) – Map from node types to node counts. Can also pass a single int for homogeneous graphs.
embedding_dim (int) – Dimension of node embeddings D. Default: 64.
num_layers (int) – Number of LightGCN propagation layers K. Default: 2.
device (torch.device) – Device to run the computation on. Default: CPU.
layer_weights (Optional[List[float]]) – Weights for [e^(0), e^(1), …, e^(K)]. Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None.
- forward(data, device, output_node_types=None, anchor_node_ids=None)[source]#
Forward pass of the LightGCN model.
- Parameters:
data (Union[Data, HeteroData]) – Graph data (homogeneous or heterogeneous).
device (torch.device) – Device to run the computation on.
output_node_types (Optional[List[NodeType]]) – List of node types to return embeddings for. Required for heterogeneous graphs. Default: None.
anchor_node_ids (Optional[torch.Tensor]) – Local node indices to return embeddings for. If None, returns embeddings for all nodes. Default: None.
- Returns:
- Node embeddings.
For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim]. For heterogeneous graphs, returns dict mapping node types to embeddings.
- Return type:
Union[torch.Tensor, Dict[NodeType, torch.Tensor]]
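A minimal usage sketch based on the signatures above (the toy graph, tensor values, and sizes are illustrative, not canonical):

```python
import torch
from torch_geometric.data import Data

from gigl.nn import LightGCN

# Toy homogeneous graph: 5 nodes, 4 undirected edges (both directions listed).
edge_index = torch.tensor(
    [[0, 1, 1, 2, 2, 3, 3, 4],
     [1, 0, 2, 1, 3, 2, 4, 3]]
)
data = Data(edge_index=edge_index, num_nodes=5)

device = torch.device("cpu")
model = LightGCN(
    node_type_to_num_nodes=5,  # a single int is allowed for homogeneous graphs
    embedding_dim=64,
    num_layers=2,              # K = 2 propagation layers
    device=device,
    layer_weights=None,        # None -> uniform weights 1 / (K + 1) over the K + 1 terms
)

all_embeddings = model(data, device)                                         # [5, 64]
some_embeddings = model(data, device, anchor_node_ids=torch.tensor([0, 2]))  # [2, 64]
```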
- class gigl.nn.LinkPredictionGNN(encoder, decoder)[source]#
Bases:
torch.nn.Module

Link Prediction GNN model for both homogeneous and heterogeneous use cases.
- Parameters:
encoder (torch.nn.Module) – Either a BasicGNN or a heterogeneous GNN for generating embeddings.
decoder (torch.nn.Module) – Decoder for transforming embeddings into scores. Recommended to use gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder.
- decode(query_embeddings, candidate_embeddings)[source]#
- Parameters:
query_embeddings (torch.Tensor)
candidate_embeddings (torch.Tensor)
- Return type:
torch.Tensor
- forward(data, device, output_node_types=None)[source]#
- Parameters:
data (Union[torch_geometric.data.Data, torch_geometric.data.HeteroData])
device (torch.device)
output_node_types (Optional[list[gigl.src.common.types.graph_data.NodeType]])
- Return type:
Union[torch.Tensor, dict[gigl.src.common.types.graph_data.NodeType, torch.Tensor]]
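A hedged end-to-end sketch: a PyG GraphSAGE encoder plus a hypothetical dot-product decoder standing in for the recommended LinkPredictionDecoder (the decoder's (query, candidate) interface is assumed from the decode signature above):

```python
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import GraphSAGE

from gigl.nn import LinkPredictionGNN

class DotProductDecoder(nn.Module):
    # Hypothetical stand-in for LinkPredictionDecoder: scores row-aligned pairs.
    def forward(self, query_embeddings: torch.Tensor, candidate_embeddings: torch.Tensor) -> torch.Tensor:
        return (query_embeddings * candidate_embeddings).sum(dim=-1)

encoder = GraphSAGE(in_channels=16, hidden_channels=32, num_layers=2, out_channels=32)
model = LinkPredictionGNN(encoder=encoder, decoder=DotProductDecoder())

data = Data(
    x=torch.randn(5, 16),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
)
device = torch.device("cpu")

embeddings = model(data, device)                        # [5, 32] node embeddings
scores = model.decode(embeddings[:2], embeddings[2:4])  # [2] link scores
```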
- to_ddp(device, find_unused_encoder_parameters=False)[source]#
Converts the model to DistributedDataParallel (DDP) mode.
We do this because DDP does not expect the forward methods of the modules it wraps to be called directly; see how DistributedDataParallel.forward calls _pre_forward (pytorch/pytorch). Without this wrapping, calling forward() on the individual modules may not work correctly.
Calling this function makes it safe to do: LinkPredictionGNN.decoder(data, device)
- Parameters:
device (Optional[torch.device]) – The device to which the model should be moved. If None, will default to CPU.
find_unused_encoder_parameters (bool) – Whether to find unused parameters in the encoder. Set this to True if the encoder has parameters that are not used in the forward pass.
- Returns:
A new instance of LinkPredictionGNN for use with DDP.
- Return type:
LinkPredictionGNN
- unwrap_from_ddp()[source]#
Unwraps the model from DistributedDataParallel if it is wrapped.
- Returns:
A new instance of LinkPredictionGNN with the original encoder and decoder.
- Return type:
LinkPredictionGNN
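A sketch of the DDP round trip under the usual assumptions (a process group already initialized by the launcher, e.g. torchrun; `model` and `data` as in the sketch above):

```python
import torch

# to_ddp wraps the encoder/decoder so their forward methods stay safe to call.
ddp_model = model.to_ddp(device=None, find_unused_encoder_parameters=False)  # None -> CPU

embeddings = ddp_model(data, torch.device("cpu"))

# Recover the plain encoder/decoder, e.g. before saving a checkpoint.
plain_model = ddp_model.unwrap_from_ddp()
```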
- property decoder: torch.nn.Module#
- Return type:
torch.nn.Module
- property encoder: torch.nn.Module#
- Return type:
torch.nn.Module
- class gigl.nn.RetrievalLoss(loss=None, temperature=None, remove_accidental_hits=False)[source]#
Bases:
torch.nn.Module

A loss layer built on top of the tensorflow_recommenders implementation: https://www.tensorflow.org/recommenders/api_docs/python/tfrs/tasks/Retrieval
By default, the loss is computed as:

`cross_entropy(torch.mm(query_embeddings, candidate_embeddings.T), positive_indices, reduction='sum')`

where candidate_embeddings is torch.cat((positive_embeddings, random_negative_embeddings)). This encourages the model to produce query embeddings whose similarity with their own first-hop neighbors is higher than with other queries' first hops and with random negatives. Rows in which a query would accidentally treat its own positives as negatives are filtered out (a minimal sketch of this computation follows the parameter list below).
- Parameters:
loss (Optional[nn.Module]) – Custom loss function to be used. If None, defaults to nn.CrossEntropyLoss(reduction="sum").
temperature (Optional[float]) – Temperature scaling applied to scores before computing cross-entropy loss. If not None, scores are divided by the temperature value.
remove_accidental_hits (bool) – Whether to remove accidental hits where the query’s positive items are also present in the negative samples.
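To make the default objective concrete, here is a minimal plain-PyTorch sketch of the computation described above (names and sizes are illustrative; this mirrors the documented formula rather than GiGL's internal code):

```python
import torch
import torch.nn.functional as F

query = torch.randn(4, 8)                       # [num_positives, D]
positives = torch.randn(4, 8)                   # one positive per query row
negatives = torch.randn(6, 8)                   # random negatives
candidates = torch.cat((positives, negatives))  # [10, D]

logits = torch.mm(query, candidates.T)          # [4, 10] similarity scores
labels = torch.arange(4)                        # row i's positive sits at column i
loss = F.cross_entropy(logits, labels, reduction="sum")
```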
- forward(repeated_candidate_scores, candidate_ids, repeated_query_ids, device, candidate_sampling_probability=None)[source]#
- Parameters:
repeated_candidate_scores (torch.Tensor) – Prediction scores between each repeated query user and each candidate. "Repeated" means each query user is repeated once per positive label it has. Tensor shape: [num_positives, num_positives + num_hard_negatives + num_random_negatives]
candidate_ids (torch.Tensor) – Concatenated IDs of the candidates. Tensor shape: [num_positives + num_hard_negatives + num_random_negatives]
repeated_query_ids (torch.Tensor) – Repeated query user IDs. Tensor shape: [num_positives]
device (torch.device) – Device to run the computation on.
candidate_sampling_probability (Optional[torch.Tensor]) – Optional tensor of candidate sampling probabilities. When given, used to correct the logits to reflect the sampling probability of negative candidates. Tensor shape: [num_positives + num_hard_negatives + num_random_negatives]
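A hedged usage sketch of forward with the shapes documented above (IDs and scores are random placeholders):

```python
import torch

from gigl.nn import RetrievalLoss

loss_fn = RetrievalLoss(temperature=0.1, remove_accidental_hits=True)

num_positives, num_negatives = 4, 6
device = torch.device("cpu")

# [num_positives, num_positives + num_negatives] score matrix
repeated_candidate_scores = torch.randn(num_positives, num_positives + num_negatives)
candidate_ids = torch.arange(num_positives + num_negatives)  # [10]
repeated_query_ids = torch.tensor([100, 100, 101, 102])      # user 100 has two positives

loss = loss_fn(
    repeated_candidate_scores,
    candidate_ids,
    repeated_query_ids,
    device,
    candidate_sampling_probability=None,
)
```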