gigl.module.models
Classes
LinkPredictionGNN: Link prediction GNN model for both homogeneous and heterogeneous use cases.
Module Contents
- class gigl.module.models.LinkPredictionGNN(encoder, decoder)
Bases: torch.nn.Module
Link prediction GNN model for both homogeneous and heterogeneous use cases. The encoder, either a BasicGNN or a heterogeneous GNN, generates node embeddings; the decoder, a LinkPredictionDecoder, transforms those embeddings into scores.
- Parameters:
encoder (torch.nn.Module)
decoder (gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder)
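The encoder/decoder composition described above can be sketched in plain PyTorch. This is a minimal, hypothetical stand-in, not GiGL's implementation: `ToyEncoder` (a single linear projection that ignores graph structure) substitutes for a real GNN such as BasicGNN, `DotProductDecoder` is an assumed scoring rule rather than LinkPredictionDecoder's actual behavior, and `ToyLinkPredictionModel` is an illustrative name.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a GNN encoder: projects node features only."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real GNN would also use edge_index; this toy ignores graph structure.
        return self.proj(x)


class DotProductDecoder(nn.Module):
    """Hypothetical decoder (assumed rule): scores every query against every candidate."""

    def forward(self, query: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
        # [num_queries, dim] @ [dim, num_candidates] -> [num_queries, num_candidates]
        return query @ candidates.T


class ToyLinkPredictionModel(nn.Module):
    """Hypothetical model that holds an encoder and a decoder as submodules."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoding step: node features -> node embeddings.
        return self.encoder(x)

    def decode(self, query: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
        # Decoding step: embeddings -> scores.
        return self.decoder(query, candidates)


torch.manual_seed(0)
model = ToyLinkPredictionModel(ToyEncoder(in_dim=8, out_dim=4), DotProductDecoder())
x = torch.randn(5, 8)                              # 5 nodes, 8 features each
embeddings = model(x)                              # shape [5, 4]
scores = model.decode(embeddings[:2], embeddings)  # shape [2, 5]
```

Keeping the two parts as separate submodules lets the encoder be swapped (for example, a homogeneous versus a heterogeneous GNN) without touching the scoring code.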
- decode(query_embeddings, candidate_embeddings)
- Parameters:
query_embeddings (torch.Tensor)
candidate_embeddings (torch.Tensor)
- Return type:
torch.Tensor
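To show how a returned score tensor might be consumed, the sketch below assumes dot-product scoring and that row i holds the scores of query i against every candidate. Both are assumptions: the real output shape and scoring rule depend on the configured LinkPredictionDecoder, and `dot_product_decode` is a hypothetical helper, not part of GiGL.

```python
import torch


def dot_product_decode(
    query_embeddings: torch.Tensor, candidate_embeddings: torch.Tensor
) -> torch.Tensor:
    """Hypothetical decode: pairwise dot-product scores, shape [num_queries, num_candidates]."""
    return query_embeddings @ candidate_embeddings.T


query = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
candidates = torch.tensor([[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]])
scores = dot_product_decode(query, candidates)
# tensor([[2., 0., 1.],
#         [0., 3., 1.]])

# Rank candidates per query; torch.topk returns (values, indices).
top_scores, top_idx = torch.topk(scores, k=2, dim=1)
# top_idx: tensor([[0, 2],
#                  [1, 2]])
```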
- forward(data, device, output_node_types=None)
- Parameters:
data (Union[torch_geometric.data.Data, torch_geometric.data.HeteroData])
device (torch.device)
output_node_types (Optional[list[gigl.src.common.types.graph_data.NodeType]])
- Return type:
Union[torch.Tensor, dict[gigl.src.common.types.graph_data.NodeType, torch.Tensor]]
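Because forward returns either a single tensor (homogeneous Data) or a dict keyed by node type (HeteroData), callers often normalize the result. The helper below is a hypothetical sketch based only on the documented return type; plain strings stand in for gigl's NodeType, which is an assumption made for illustration.

```python
from typing import Union

import torch

# Plain strings stand in for gigl's NodeType here (assumption for illustration).
ForwardOutput = Union[torch.Tensor, dict[str, torch.Tensor]]


def embeddings_for(output: ForwardOutput, node_type: str) -> torch.Tensor:
    """Hypothetical helper: fetch one node type's embeddings from a forward() result."""
    if isinstance(output, torch.Tensor):
        # Homogeneous graph: a single tensor covers all nodes, so node_type is ignored.
        return output
    if node_type not in output:
        raise KeyError(f"node type {node_type!r} not present in forward() output")
    # Heterogeneous graph: embeddings are grouped by node type.
    return output[node_type]


homogeneous_out = torch.zeros(4, 16)
hetero_out = {"user": torch.zeros(3, 16), "item": torch.ones(7, 16)}
print(embeddings_for(homogeneous_out, "user").shape)  # torch.Size([4, 16])
print(embeddings_for(hetero_out, "item").shape)       # torch.Size([7, 16])
```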