gigl.src.common.models.pyg.link_prediction
Classes
LinkPredictionGNN | Link Prediction GNN model for both homogeneous and heterogeneous use cases
Module Contents
- class gigl.src.common.models.pyg.link_prediction.LinkPredictionGNN(encoder, decoder, tasks=None)
Bases: torch.nn.Module
Link Prediction GNN model for both homogeneous and heterogeneous use cases.
encoder (nn.Module): Either a BasicGNN or a heterogeneous GNN that generates node embeddings.
decoder (nn.Module): A LinkPredictionDecoder that transforms embeddings into scores.
tasks (NodeAnchorBasedLinkPredictionTasks): Learning tasks used during training (e.g. Retrieval, Margin, SSL, …).
- Parameters:
encoder (torch.nn.Module)
decoder (gigl.src.common.models.layers.decoder.LinkPredictionDecoder)
tasks (Optional[gigl.src.common.models.layers.task.NodeAnchorBasedLinkPredictionTasks])
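The constructor composes three pieces: an encoder that produces node embeddings, a decoder that turns embeddings into scores, and optional tasks that only matter during training. The sketch below mimics that composition in plain Python so it runs without PyTorch. `ToyLinkPredictionModel`, `scale_encoder` and `dot_decoder` are hypothetical illustrations, not GiGL APIs, and the toy encoder does no message passing.

```python
from typing import Callable, Dict, List, Optional, Sequence

Vector = List[float]
Encoder = Callable[[Dict[str, List[Vector]]], Dict[str, List[Vector]]]
Decoder = Callable[[List[Vector], List[Vector]], List[List[float]]]


class ToyLinkPredictionModel:
    """Hypothetical stand-in for the encoder/decoder/tasks composition."""

    def __init__(self, encoder: Encoder, decoder: Decoder,
                 tasks: Optional[Sequence[str]] = None) -> None:
        self.encoder = encoder
        self.decoder = decoder
        # Tasks are only needed for training; inference works with tasks=None.
        self.tasks = list(tasks) if tasks is not None else []

    def embed(self, features: Dict[str, List[Vector]]) -> Dict[str, List[Vector]]:
        # Delegate embedding generation to the encoder.
        return self.encoder(features)

    def score(self, query: List[Vector], candidates: List[Vector]) -> List[List[float]]:
        # Delegate scoring to the decoder.
        return self.decoder(query, candidates)


def scale_encoder(features: Dict[str, List[Vector]]) -> Dict[str, List[Vector]]:
    # Toy "GNN": doubles every feature value (no message passing).
    return {nt: [[2.0 * x for x in row] for row in rows] for nt, rows in features.items()}


def dot_decoder(query: List[Vector], candidates: List[Vector]) -> List[List[float]]:
    # Score every query against every candidate with a dot product.
    return [[sum(q * c for q, c in zip(qv, cv)) for cv in candidates] for qv in query]


model = ToyLinkPredictionModel(scale_encoder, dot_decoder)
emb = model.embed({"user": [[1.0, 0.0]], "item": [[0.5, 0.5], [0.0, 1.0]]})
print(model.score(emb["user"], emb["item"]))  # [[2.0, 0.0]]
```

Keeping the encoder and decoder as separate, swappable components means the same model wrapper can serve a homogeneous or heterogeneous encoder without changing how scores are produced.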
- decode(query_embeddings, candidate_embeddings)
- Parameters:
query_embeddings (torch.Tensor)
candidate_embeddings (torch.Tensor)
- Return type:
torch.Tensor
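decode maps query embeddings and candidate embeddings to a tensor of scores. As a rough illustration, assuming an inner-product decoder (one common choice; the decoder actually configured in a LinkPredictionDecoder may differ), N queries against M candidates yield an N x M score matrix. `inner_product_decode` and `top_k` are hypothetical helpers written in plain Python rather than torch.

```python
from typing import List

Matrix = List[List[float]]


def inner_product_decode(query_embeddings: Matrix, candidate_embeddings: Matrix) -> Matrix:
    """Return an N x M matrix where entry [i][j] = <query_i, candidate_j>.

    Assumption: inner-product scoring; the real decoder may differ.
    """
    if query_embeddings:
        dim = len(query_embeddings[0])
        for row in query_embeddings + candidate_embeddings:
            if len(row) != dim:
                raise ValueError("all embeddings must share the same dimension")
    return [
        [sum(q * c for q, c in zip(q_row, c_row)) for c_row in candidate_embeddings]
        for q_row in query_embeddings
    ]


def top_k(scores: Matrix, k: int) -> List[List[int]]:
    """Indices of the k highest-scoring candidates for each query, best first."""
    return [sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k] for row in scores]


queries = [[1.0, 0.0], [0.0, 1.0]]
candidates = [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]
scores = inner_product_decode(queries, candidates)
print(scores)            # [[0.9, 0.2, 0.5], [0.1, 0.8, 0.5]]
print(top_k(scores, 2))  # [[0, 2], [1, 2]]
```

Ranking the score matrix row by row is how such scores are typically consumed for retrieval-style evaluation.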
- forward(data, output_node_types, device)
- Parameters:
data (Union[torch_geometric.data.Data, torch_geometric.data.hetero_data.HeteroData])
output_node_types (List[gigl.src.common.types.graph_data.NodeType])
device (torch.device)
- Return type:
Dict[gigl.src.common.types.graph_data.NodeType, torch.Tensor]
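forward returns a dictionary keyed by node type, restricted to output_node_types. The plain-Python sketch below shows one way to give homogeneous and heterogeneous inputs the same dict-shaped output. It assumes a homogeneous graph is treated as a single node type; `DEFAULT_NODE_TYPE`, `forward_sketch` and `add_one` are hypothetical and do not reflect how GiGL labels homogeneous node types internally. The device argument is omitted because nothing here runs on an accelerator.

```python
from typing import Callable, Dict, List, Union

Vector = List[float]
TypedFeatures = Dict[str, List[Vector]]
Features = Union[List[Vector], TypedFeatures]

# Hypothetical label for the single node type of a homogeneous graph.
DEFAULT_NODE_TYPE = "default"


def forward_sketch(
    features: Features,
    output_node_types: List[str],
    encode: Callable[[TypedFeatures], TypedFeatures],
) -> TypedFeatures:
    """Encode the whole graph, then return embeddings for the requested types only."""
    # Wrap homogeneous input (a plain list of rows) as one node type so both
    # cases produce the same dict-shaped output.
    if isinstance(features, list):
        features = {DEFAULT_NODE_TYPE: features}
    embeddings = encode(features)
    missing = [nt for nt in output_node_types if nt not in embeddings]
    if missing:
        raise KeyError(f"unknown node types: {missing}")
    return {nt: embeddings[nt] for nt in output_node_types}


def add_one(features: TypedFeatures) -> TypedFeatures:
    # Toy encoder standing in for a GNN: adds 1.0 to every feature value.
    return {nt: [[x + 1.0 for x in row] for row in rows] for nt, rows in features.items()}


hetero = {"user": [[0.0, 1.0]], "item": [[2.0, 3.0]]}
print(forward_sketch(hetero, ["item"], add_one))                      # {'item': [[3.0, 4.0]]}
print(forward_sketch([[1.0, 1.0]], [DEFAULT_NODE_TYPE], add_one))     # {'default': [[2.0, 2.0]]}
```

Encoding the full graph before selecting types matters for a real GNN, since embeddings of one node type depend on messages from neighbors of other types.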
- property graph_backend: gigl.src.common.types.model.GraphBackend
- Return type: