gigl.module.models

Classes

LinkPredictionGNN: Link Prediction GNN model for both homogeneous and heterogeneous use cases.

Module Contents

class gigl.module.models.LinkPredictionGNN(encoder, decoder)

Bases: torch.nn.Module

Link Prediction GNN model for both homogeneous and heterogeneous use cases.

Parameters:
  • encoder (torch.nn.Module): Either a BasicGNN or a heterogeneous GNN, used to generate embeddings.

  • decoder (gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder): LinkPredictionDecoder that transforms embeddings into scores.
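The encoder/decoder split can be illustrated with a minimal, self-contained PyTorch sketch. It assumes that the encoder maps node features plus an `edge_index` to one embedding per node, and that the decoder scores query/candidate pairs by inner product. `ToyEncoder`, `InnerProductDecoder`, and `SketchLinkPredictionModel` are hypothetical names used only for illustration; they are not part of GiGL and do not reproduce its actual implementation.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a GNN encoder such as a BasicGNN.

    Performs one round of mean neighbor aggregation, then a linear projection.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        # Sum source-node features into each destination node.
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        # In-degree per node, clamped to 1 so isolated nodes avoid division by zero.
        deg = torch.zeros(x.size(0), dtype=x.dtype).index_add_(
            0, dst, torch.ones(dst.numel(), dtype=x.dtype)
        ).clamp(min=1)
        # Combine each node's own features with the mean of its in-neighbors.
        return self.lin(x + agg / deg.unsqueeze(-1))


class InnerProductDecoder(nn.Module):
    """Hypothetical decoder: scores every (query, candidate) pair by dot product."""

    def forward(self, query: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
        return query @ candidates.t()


class SketchLinkPredictionModel(nn.Module):
    """Encoder/decoder composition in the spirit of LinkPredictionGNN (not GiGL code)."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # Embeddings come from the encoder only.
        return self.encoder(x, edge_index)

    def decode(self, query_embeddings: torch.Tensor, candidate_embeddings: torch.Tensor) -> torch.Tensor:
        # Scoring is delegated entirely to the decoder.
        return self.decoder(query_embeddings, candidate_embeddings)


torch.manual_seed(0)
model = SketchLinkPredictionModel(ToyEncoder(in_dim=4, out_dim=8), InnerProductDecoder())
x = torch.randn(5, 4)                                    # 5 nodes, 4 features each
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # edges 0->1, 1->2, 2->3, 3->4
embeddings = model(x, edge_index)                        # shape [5, 8]
scores = model.decode(embeddings[:2], embeddings)        # shape [2, 5]
```

Keeping the encoder and decoder as separate submodules means either can be replaced independently; for example, a different scoring function can be plugged in without touching the encoder.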

decode(query_embeddings, candidate_embeddings)
Parameters:
  • query_embeddings (torch.Tensor)

  • candidate_embeddings (torch.Tensor)

Return type:

torch.Tensor
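The signature does not state the shape of the returned scores. Assuming an inner-product decoder (a common choice for link prediction, but not confirmed here for `LinkPredictionDecoder`), `decode` would produce one score per (query, candidate) pair, which can then be ranked to retrieve the most likely links. `inner_product_scores` and `top_k_candidates` are hypothetical helpers written only for this sketch.

```python
import torch


def inner_product_scores(query_embeddings: torch.Tensor,
                         candidate_embeddings: torch.Tensor) -> torch.Tensor:
    """Assumed decode behavior: one score per (query, candidate) pair.

    query_embeddings:     [num_queries, dim]
    candidate_embeddings: [num_candidates, dim]
    returns:              [num_queries, num_candidates]
    """
    if query_embeddings.size(-1) != candidate_embeddings.size(-1):
        raise ValueError("query and candidate embeddings must share the last dimension")
    return query_embeddings @ candidate_embeddings.t()


def top_k_candidates(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Indices of the k highest-scoring candidates for each query, best first."""
    return torch.topk(scores, k=k, dim=1).indices


query = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
candidates = torch.tensor([[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]])
scores = inner_product_scores(query, candidates)  # [[2., 0., 1.], [0., 3., 1.]]
top = top_k_candidates(scores, k=2)               # [[0, 2], [1, 2]]
```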

forward(data, device, output_node_types=None)
Parameters:
  • data (Union[torch_geometric.data.Data, torch_geometric.data.HeteroData])

  • device (torch.device)

  • output_node_types (Optional[list[gigl.src.common.types.graph_data.NodeType]])

Return type:

Union[torch.Tensor, dict[gigl.src.common.types.graph_data.NodeType, torch.Tensor]]
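The return annotation suggests that `forward` yields a single tensor for a homogeneous `Data` input and a dictionary keyed by node type for a `HeteroData` input. The sketch below mimics that dispatch under stated assumptions: plain tensors and dicts stand in for PyG's `Data` and `HeteroData` (so `torch_geometric` is not required), `NodeType` is aliased to `str`, `ToyHeteroEncoder` and `sketch_forward` are hypothetical, and returning every node type when `output_node_types` is `None` is an assumption rather than documented GiGL behavior.

```python
from typing import Optional, Union

import torch
from torch import nn

NodeType = str  # hypothetical stand-in for gigl.src.common.types.graph_data.NodeType


class ToyHeteroEncoder(nn.Module):
    """Hypothetical heterogeneous encoder: one linear projection per node type."""

    def __init__(self, in_dims: dict[NodeType, int], out_dim: int) -> None:
        super().__init__()
        self.lins = nn.ModuleDict({t: nn.Linear(d, out_dim) for t, d in in_dims.items()})

    def forward(self, x_dict: dict[NodeType, torch.Tensor]) -> dict[NodeType, torch.Tensor]:
        return {t: self.lins[t](x) for t, x in x_dict.items()}


def sketch_forward(
    encoder: nn.Module,
    data: Union[torch.Tensor, dict[NodeType, torch.Tensor]],
    device: torch.device,
    output_node_types: Optional[list[NodeType]] = None,
) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
    """Dispatch on graph type, mirroring the signature of LinkPredictionGNN.forward."""
    encoder = encoder.to(device)
    if isinstance(data, dict):
        # Heterogeneous case: move every node type's features to the device.
        embeddings = encoder({t: x.to(device) for t, x in data.items()})
        if output_node_types is None:
            # Assumption: return all node types when none are requested.
            return embeddings
        # Keep only the requested node types.
        return {t: embeddings[t] for t in output_node_types}
    # Homogeneous case: a single embedding tensor; output_node_types is ignored.
    return encoder(data.to(device))


torch.manual_seed(0)
hetero_encoder = ToyHeteroEncoder({"user": 3, "item": 5}, out_dim=4)
hetero_data = {"user": torch.randn(2, 3), "item": torch.randn(6, 5)}
item_only = sketch_forward(hetero_encoder, hetero_data, torch.device("cpu"), output_node_types=["item"])
```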