gigl.module.models#
Classes#
| LightGCN | LightGCN model with TorchRec integration for distributed ID embeddings. |
| LinkPredictionGNN | Link Prediction GNN model for both homogeneous and heterogeneous use cases. |
Module Contents#
- class gigl.module.models.LightGCN(node_type_to_num_nodes, embedding_dim=64, num_layers=2, device=torch.device('cpu'), layer_weights=None)[source]#
- Bases: torch.nn.Module
- LightGCN model with TorchRec integration for distributed ID embeddings.
- Reference: https://arxiv.org/pdf/2002.02126
- This class extends the basic LightGCN implementation to use TorchRec’s distributed embedding tables for handling large-scale ID embeddings.
- Parameters:
- node_type_to_num_nodes (Union[int, Dict[NodeType, int]]) – Map from node types to node counts. Can also pass a single int for homogeneous graphs. 
- embedding_dim (int) – Dimension of node embeddings D. Default: 64. 
- num_layers (int) – Number of LightGCN propagation layers K. Default: 2. 
- device (torch.device) – Device to run the computation on. Default: CPU. 
- layer_weights (Optional[List[float]]) – Weights for [e^(0), e^(1), …, e^(K)]. Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None. 
 
- Initialize internal Module state, shared by both nn.Module and ScriptModule.
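A minimal construction sketch for the homogeneous case; the node count is hypothetical, and layer_weights spells out the uniform default for K = 2 layers (K+1 = 3 weights):

```python
import torch

from gigl.module.models import LightGCN

# Hypothetical homogeneous graph with 10,000 nodes. layer_weights must have
# length K+1 = 3; the values below match the uniform default of 1/(K+1).
model = LightGCN(
    node_type_to_num_nodes=10_000,
    embedding_dim=64,
    num_layers=2,
    device=torch.device("cpu"),
    layer_weights=[1 / 3, 1 / 3, 1 / 3],
)
```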
- forward(data, device, output_node_types=None, anchor_node_ids=None)[source]#
- Forward pass of the LightGCN model.
- Parameters:
- data (Union[Data, HeteroData]) – Graph data (homogeneous or heterogeneous). 
- device (torch.device) – Device to run the computation on. 
- output_node_types (Optional[List[NodeType]]) – List of node types to return embeddings for. Required for heterogeneous graphs. Default: None. 
- anchor_node_ids (Optional[torch.Tensor]) – Local node indices to return embeddings for. If None, returns embeddings for all nodes. Default: None. 
 
- Returns:
- Node embeddings. For homogeneous graphs, a tensor of shape [num_nodes, embedding_dim]; for heterogeneous graphs, a dict mapping node types to embedding tensors.
 
- Return type:
- Union[torch.Tensor, Dict[NodeType, torch.Tensor]] 
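A usage sketch for the homogeneous case, assuming a standard PyG Data object; the toy graph and the anchor IDs are illustrative:

```python
import torch
from torch_geometric.data import Data

from gigl.module.models import LightGCN

device = torch.device("cpu")
model = LightGCN(node_type_to_num_nodes=10_000, device=device)

# Toy 4-node graph; each undirected edge appears as two directed pairs.
data = Data(
    edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    num_nodes=4,
)

embeddings = model(data, device)  # tensor of shape [4, 64]

# Restrict the output to two anchor nodes.
anchors = model(data, device, anchor_node_ids=torch.tensor([0, 2]))  # [2, 64]
```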
 
 
- class gigl.module.models.LinkPredictionGNN(encoder, decoder)[source]#
- Bases: torch.nn.Module
- Link Prediction GNN model for both homogeneous and heterogeneous use cases.
- Parameters:
- encoder (torch.nn.Module) – Either a BasicGNN or a heterogeneous GNN for generating embeddings.
- decoder (torch.nn.Module) – Decoder for transforming embeddings into scores. Recommended to use gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder.
 
- Initialize internal Module state, shared by both nn.Module and ScriptModule.
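A construction sketch; GraphSAGE stands in for “a BasicGNN”, and the dot-product module below is a hypothetical stand-in for the recommended LinkPredictionDecoder:

```python
import torch
import torch.nn as nn
from torch_geometric.nn import GraphSAGE

from gigl.module.models import LinkPredictionGNN


class DotProductDecoder(nn.Module):
    """Hypothetical stand-in for the recommended LinkPredictionDecoder."""

    def forward(
        self, query_embeddings: torch.Tensor, candidate_embeddings: torch.Tensor
    ) -> torch.Tensor:
        # Score each query/candidate pair with a dot product.
        return (query_embeddings * candidate_embeddings).sum(dim=-1)


encoder = GraphSAGE(in_channels=16, hidden_channels=32, num_layers=2)
model = LinkPredictionGNN(encoder=encoder, decoder=DotProductDecoder())
```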
- decode(query_embeddings, candidate_embeddings)[source]#
- Parameters:
- query_embeddings (torch.Tensor) 
- candidate_embeddings (torch.Tensor) 
 
- Return type:
- torch.Tensor 
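A decode sketch, reusing the model from the construction example above; the batch of embeddings is hypothetical:

```python
import torch

# Hypothetical batch of 8 query/candidate embedding pairs (dim 32, matching
# the encoder's hidden_channels above).
query = torch.randn(8, 32)
candidates = torch.randn(8, 32)
scores = model.decode(query, candidates)  # shape [8] with the dot-product stand-in
```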
 
 - forward(data, device, output_node_types=None)[source]#
- Parameters:
- data (Union[torch_geometric.data.Data, torch_geometric.data.HeteroData]) 
- device (torch.device) 
- output_node_types (Optional[list[gigl.src.common.types.graph_data.NodeType]]) 
 
- Return type:
- Union[torch.Tensor, dict[gigl.src.common.types.graph_data.NodeType, torch.Tensor]] 
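A forward-pass sketch for the homogeneous case, reusing the model above; the node features and edges are a toy example:

```python
import torch
from torch_geometric.data import Data

# Toy graph: 4 nodes with 16-dim features to match the encoder's in_channels.
data = Data(
    x=torch.randn(4, 16),
    edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
)
embeddings = model(data, device=torch.device("cpu"))  # shape [4, 32]
```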
 
 - to_ddp(device, find_unused_encoder_parameters=False)[source]#
- Converts the model to DistributedDataParallel (DDP) mode.
- We do this because DDP does not expect the forward method of the modules it wraps to be called directly; see how DistributedDataParallel.forward calls _pre_forward in pytorch/pytorch. Without this, calling forward() on the individual modules may not work correctly.
- Calling this function makes it safe to do: LinkPredictionGNN.decoder(data, device)
- Parameters:
- device (Optional[torch.device]) – The device to which the model should be moved. If None, will default to CPU. 
- find_unused_encoder_parameters (bool) – Whether to find unused parameters in the model. This should be set to True if the model has parameters that are not used in the forward pass. 
 
- Returns:
- A new instance of LinkPredictionGNN for use with DDP. 
- Return type:
- LinkPredictionGNN
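A sketch of the intended call pattern, assuming torch.distributed.init_process_group has already run in each worker; the device choice is illustrative:

```python
import torch

# After torch.distributed.init_process_group(...) in each worker process:
ddp_model = model.to_ddp(
    device=torch.device("cuda", 0),
    find_unused_encoder_parameters=False,
)
# Per the note above, sub-modules of the returned model can be called
# directly, e.g. ddp_model.decode(query, candidates).
```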
 
