gigl.src.common.models.pyg.link_prediction
Attributes
Classes
| LinkPredictionGNN | Link Prediction GNN model for both homogeneous and heterogeneous use cases |
Module Contents
- class gigl.src.common.models.pyg.link_prediction.LinkPredictionGNN(encoder, decoder, tasks=None)
- Bases: torch.nn.Module

Link Prediction GNN model for both homogeneous and heterogeneous use cases.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

- Parameters:
- encoder (torch.nn.Module): Either a BasicGNN or a heterogeneous GNN, used to generate embeddings.
- decoder (gigl.src.common.models.layers.decoder.LinkPredictionDecoder): A LinkPredictionDecoder that transforms embeddings into scores.
- tasks (Optional[gigl.src.common.models.layers.task.NodeAnchorBasedLinkPredictionTasks]): Learning tasks used for training (e.g., Retrieval, Margin, SSL, …). Defaults to None.
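The parameters above describe an encoder/decoder split: the encoder produces embeddings and the decoder turns them into scores. The following is a minimal sketch of that composition in plain PyTorch, not GiGL's implementation. `ToyEncoder`, `ToyBilinearDecoder`, and `ToyLinkPredictionModel` are hypothetical names, and the encoder is a single linear layer standing in for a real GNN (it ignores graph structure).

```python
from typing import Optional

import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a GNN encoder: projects node features to embeddings.

    A real GNN encoder would also consume graph structure (e.g. edge_index).
    """

    def __init__(self, in_dim: int, emb_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_dim, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ToyBilinearDecoder(nn.Module):
    """Hypothetical decoder: score(q, c) = q^T W c for every query/candidate pair."""

    def __init__(self, emb_dim: int) -> None:
        super().__init__()
        # Initialized to the identity, so it starts out as a plain inner product.
        self.weight = nn.Parameter(torch.eye(emb_dim))

    def forward(self, query: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
        # (Q, D) @ (D, D) @ (D, C) -> (Q, C)
        return query @ self.weight @ candidates.T


class ToyLinkPredictionModel(nn.Module):
    """Hypothetical encoder/decoder wrapper loosely modeled on LinkPredictionGNN."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module, tasks: Optional[object] = None) -> None:
        super().__init__()  # must run before any submodule is assigned
        self.encoder = encoder  # nn.Module attributes are registered as child modules
        self.decoder = decoder
        self.tasks = tasks  # training objectives; not used in this sketch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, query_embeddings: torch.Tensor, candidate_embeddings: torch.Tensor) -> torch.Tensor:
        return self.decoder(query_embeddings, candidate_embeddings)


model = ToyLinkPredictionModel(ToyEncoder(in_dim=8, emb_dim=4), ToyBilinearDecoder(emb_dim=4))
# Linear(8, 4) contributes 32 + 4 parameters, the 4x4 decoder weight contributes 16.
num_params = sum(p.numel() for p in model.parameters())
emb = model(torch.randn(10, 8))  # 10 nodes -> (10, 4) embeddings
scores = model.decode(emb[:2], emb[2:])  # 2 queries vs 8 candidates -> (2, 8)
```

Because the encoder and decoder are assigned as attributes of an nn.Module, their parameters are registered automatically, so a single optimizer over model.parameters() trains both.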
 
- decode(query_embeddings, candidate_embeddings)
- Parameters:
- query_embeddings (torch.Tensor) 
- candidate_embeddings (torch.Tensor) 
 
- Return type:
- torch.Tensor 
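The reference lists only tensor types for decode. As an illustration, the sketch below assumes an inner-product decoder, a common choice for retrieval-style link prediction; whether GiGL's LinkPredictionDecoder scores pairs this way is not stated here, and `inner_product_scores` is a hypothetical helper. It also shows how a (num_queries, num_candidates) score matrix can be turned into top-k retrieval results.

```python
import torch


def inner_product_scores(query_embeddings: torch.Tensor, candidate_embeddings: torch.Tensor) -> torch.Tensor:
    """Hypothetical decoder: score every (query, candidate) pair by dot product.

    query_embeddings: (num_queries, dim)
    candidate_embeddings: (num_candidates, dim)
    returns: (num_queries, num_candidates)
    """
    return query_embeddings @ candidate_embeddings.T


query = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
candidates = torch.tensor([[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]])
scores = inner_product_scores(query, candidates)  # [[2., 0., 1.], [0., 3., 1.]]

# Retrieval: the 2 highest-scoring candidates per query, sorted by score.
top_scores, top_idx = torch.topk(scores, k=2, dim=1)  # top_idx: [[0, 2], [1, 2]]
```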
 
- forward(data, output_node_types, device)
- Parameters:
- data (Union[torch_geometric.data.Data, torch_geometric.data.hetero_data.HeteroData]) 
- output_node_types (list[gigl.src.common.types.graph_data.NodeType]) 
- device (torch.device) 
 
- Return type:
- dict[gigl.src.common.types.graph_data.NodeType, torch.Tensor] 
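The return type says forward produces a mapping from node type to embedding tensor for the requested output_node_types. A minimal sketch of that output contract, assuming the encoder has already produced either one tensor (homogeneous graph) or a dict keyed by node type (heterogeneous graph); `select_output_embeddings` and the string `NodeType` alias are hypothetical, and GiGL's actual handling of the homogeneous case may differ.

```python
from typing import Union

import torch

NodeType = str  # hypothetical alias; GiGL defines its own NodeType


def select_output_embeddings(
    embeddings: Union[torch.Tensor, dict[NodeType, torch.Tensor]],
    output_node_types: list[NodeType],
    device: torch.device,
) -> dict[NodeType, torch.Tensor]:
    """Hypothetical post-processing step shaped like forward()'s return type."""
    if isinstance(embeddings, torch.Tensor):
        # Homogeneous graph: one embedding matrix, so exactly one node type is expected.
        if len(output_node_types) != 1:
            raise ValueError("homogeneous graphs have exactly one node type")
        return {output_node_types[0]: embeddings.to(device)}
    # Heterogeneous graph: keep only the requested node types, moved to the target device.
    return {nt: embeddings[nt].to(device) for nt in output_node_types}


device = torch.device("cpu")
hetero = {"user": torch.zeros(3, 4), "item": torch.ones(5, 4)}
out = select_output_embeddings(hetero, ["item"], device)
homo = select_output_embeddings(torch.zeros(6, 4), ["node"], device)
```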
 
- property graph_backend: gigl.src.common.types.model.GraphBackend
- Return type:
- gigl.src.common.types.model.GraphBackend
 
 
