gigl.module.models
Classes
LinkPredictionGNN – Link Prediction GNN model for both homogeneous and heterogeneous use cases
Module Contents
- class gigl.module.models.LinkPredictionGNN(encoder, decoder)
Bases: torch.nn.Module
Link Prediction GNN model for both homogeneous and heterogeneous use cases.
- Parameters:
encoder (torch.nn.Module) – Either a BasicGNN or a heterogeneous GNN, used to generate embeddings.
decoder (nn.Module) – Decoder that transforms embeddings into scores. gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder is recommended.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
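The encoder/decoder split can be illustrated with a minimal sketch in plain PyTorch. `ToyLinkPredictionModel` is a hypothetical class, not gigl's implementation: it takes a node-feature tensor instead of a `Data`/`HeteroData` object, uses an `nn.Linear` layer as a stand-in encoder (a real GNN would also use the edges), and uses `nn.Bilinear` as a stand-in for `LinkPredictionDecoder`.

```python
import torch
from torch import nn


class ToyLinkPredictionModel(nn.Module):
    """Hypothetical sketch of the encoder/decoder split; not gigl's implementation."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder  # produces node embeddings
        self.decoder = decoder  # turns pairs of embeddings into scores

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Embedding step only; scoring is a separate decode() call.
        return self.encoder(x)

    def decode(self, query: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
        return self.decoder(query, candidates)


torch.manual_seed(0)
# Stand-ins: a Linear layer instead of a real GNN (it ignores edges entirely),
# and a learned bilinear form instead of LinkPredictionDecoder.
model = ToyLinkPredictionModel(encoder=nn.Linear(8, 4), decoder=nn.Bilinear(4, 4, 1))

x = torch.randn(6, 8)                    # 6 nodes with 8 features each
emb = model(x)                           # (6, 4) node embeddings
scores = model.decode(emb[:3], emb[3:])  # scores for node pairs (0,3), (1,4), (2,5)
print(emb.shape, scores.shape)  # torch.Size([6, 4]) torch.Size([3, 1])
```

Keeping embedding and scoring as separate calls means embeddings can be computed once and then scored against many candidate sets.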
- decode(query_embeddings, candidate_embeddings)
- Parameters:
query_embeddings (torch.Tensor)
candidate_embeddings (torch.Tensor)
- Return type:
torch.Tensor
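The `decode` signature does not state how candidates are paired with queries. Assuming a dot-product scorer (an assumption; `LinkPredictionDecoder` may score differently), the sketch below shows two common conventions and how they relate: aligned scoring, where query i is compared only with candidate i, and all-pairs scoring, where every query is compared with every candidate.

```python
import torch

torch.manual_seed(0)
query_embeddings = torch.randn(3, 8)      # 3 queries, 8-dim embeddings
candidate_embeddings = torch.randn(3, 8)  # one candidate per query, same order

# Aligned scoring: query i against candidate i only -> shape (3,)
aligned_scores = (query_embeddings * candidate_embeddings).sum(dim=-1)

# All-pairs scoring: every query against every candidate -> shape (3, 3)
all_pair_scores = query_embeddings @ candidate_embeddings.T

# The aligned scores are the diagonal of the all-pairs matrix
# (up to floating-point rounding, hence the explicit tolerance).
assert torch.allclose(aligned_scores, all_pair_scores.diagonal(), atol=1e-5)
print(aligned_scores.shape, all_pair_scores.shape)  # torch.Size([3]) torch.Size([3, 3])
```

All-pairs scoring is the usual choice when each query is ranked against a shared pool of candidates; aligned scoring is cheaper when positives and negatives are already paired.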
- forward(data, device, output_node_types=None)
- Parameters:
data (Union[torch_geometric.data.Data, torch_geometric.data.HeteroData])
device (torch.device)
output_node_types (Optional[list[gigl.src.common.types.graph_data.NodeType]])
- Return type:
Union[torch.Tensor, dict[gigl.src.common.types.graph_data.NodeType, torch.Tensor]]
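The return type of `forward` depends on the input graph: a single tensor for homogeneous `Data`, and a dict keyed by node type for `HeteroData`. The hypothetical helper `select_output_embeddings` below sketches that contract on precomputed embeddings. It assumes node types are plain strings and that `output_node_types` selects which node types are returned; raising when it is `None` for heterogeneous input is a choice made for this sketch, not documented behavior.

```python
from typing import Optional, Union

import torch

# Assumption for this sketch: node types are plain strings.
NodeType = str
Embeddings = Union[torch.Tensor, dict[NodeType, torch.Tensor]]


def select_output_embeddings(
    embeddings: Embeddings,
    output_node_types: Optional[list[NodeType]] = None,
) -> Embeddings:
    """Hypothetical helper mirroring the documented return type of forward()."""
    if isinstance(embeddings, torch.Tensor):
        # Homogeneous graph: one (num_nodes, dim) tensor, returned unchanged.
        return embeddings
    if output_node_types is None:
        # Choice made for this sketch: heterogeneous output needs explicit node types.
        raise ValueError("output_node_types is required for heterogeneous embeddings")
    # Heterogeneous graph: keep only the requested node types.
    return {node_type: embeddings[node_type] for node_type in output_node_types}


homo = torch.randn(4, 16)
hetero = {"user": torch.randn(3, 16), "item": torch.randn(5, 16)}

out_homo = select_output_embeddings(homo)
out_hetero = select_output_embeddings(hetero, output_node_types=["item"])
print(type(out_homo).__name__, list(out_hetero))  # Tensor ['item']
```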
- to_ddp(device, find_unused_encoder_parameters=False)
Converts the model to DistributedDataParallel (DDP) mode.
This is needed because DDP does not expect the forward method of the modules it wraps to be called directly; see how DistributedDataParallel.forward calls _pre_forward (pytorch/pytorch). Without this conversion, calling forward() on the individual modules may not work correctly.
Calling this function makes it safe to do: LinkPredictionGNN.decoder(data, device)
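The reasoning above suggests wrapping the encoder and decoder in DDP individually, so that each submodule is invoked through `DistributedDataParallel.forward`. The sketch below does this on a single-process CPU group with the gloo backend. `TwoPartModel` and `wrap_submodules_in_ddp` are hypothetical; the sketch assumes, but does not confirm, that `to_ddp` follows the same pattern.

```python
import copy
import os
import tempfile

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


class TwoPartModel(nn.Module):
    """Hypothetical model with separately callable encoder and decoder submodules."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        self.decoder = nn.Bilinear(4, 4, 1)


def wrap_submodules_in_ddp(
    model: TwoPartModel, find_unused_encoder_parameters: bool = False
) -> TwoPartModel:
    # Work on a copy so a new instance is returned and the original is untouched.
    wrapped = copy.deepcopy(model)
    # Wrap each submodule on its own, so model.encoder(...) and model.decoder(...)
    # go through DDP.forward. CPU modules: device_ids is left as None.
    wrapped.encoder = DDP(wrapped.encoder, find_unused_parameters=find_unused_encoder_parameters)
    wrapped.decoder = DDP(wrapped.decoder)
    return wrapped


# Single-process "cluster" on CPU with gloo; rendezvous through a fresh temp file.
init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
dist.init_process_group(backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

model = wrap_submodules_in_ddp(TwoPartModel())
with torch.no_grad():
    emb = model.encoder(torch.randn(6, 8))    # runs through DDP.forward
    scores = model.decoder(emb[:3], emb[3:])  # likewise for the decoder

dist.destroy_process_group()
print(type(model.encoder).__name__, scores.shape)  # DistributedDataParallel torch.Size([3, 1])
```

In real training each process would run the same code with its own `rank` and a shared `world_size`; for single-GPU modules, `device_ids=[device]` would normally be passed to `DDP`.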
- Parameters:
device (Optional[torch.device]) – The device to which the model should be moved. If None, defaults to CPU.
find_unused_encoder_parameters (bool) – Whether DDP should look for unused parameters in the encoder. Set this to True if the encoder has parameters that are not used in the forward pass.
- Returns:
A new instance of LinkPredictionGNN for use with DDP.
- Return type: