Source code for gigl.module.models

from typing import Optional, Union

import torch
import torch.nn as nn
from torch_geometric.data import Data, HeteroData

from gigl.src.common.models.pyg.link_prediction import LinkPredictionDecoder
from gigl.src.common.types.graph_data import NodeType


class LinkPredictionGNN(nn.Module):
    """
    Link Prediction GNN model for both homogeneous and heterogeneous use cases.

    Combines an encoder, which produces embeddings from a graph, with a
    decoder, which scores query embeddings against candidate embeddings.

    Args:
        encoder (nn.Module): Either BasicGNN or Heterogeneous GNN for generating embeddings
        decoder (nn.Module): LinkPredictionDecoder for transforming embeddings into scores
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: LinkPredictionDecoder,
    ) -> None:
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # submodules, so the encoder's parameters (and the decoder's,
        # assuming it is an nn.Module -- TODO confirm) are tracked here.
        self._encoder = encoder
        self._decoder = decoder

    def forward(
        self,
        data: Union[Data, HeteroData],
        device: torch.device,
        output_node_types: Optional[list[NodeType]] = None,
    ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
        """
        Run the encoder on a graph and return its output.

        Args:
            data (Union[Data, HeteroData]): Input graph. A ``HeteroData``
                instance selects the heterogeneous path; anything else is
                passed to the encoder as a homogeneous graph.
            device (torch.device): Device forwarded to the encoder.
            output_node_types (Optional[list[NodeType]]): Node types the
                encoder should produce output for. Required for
                heterogeneous input; ignored for homogeneous input.

        Returns:
            Union[torch.Tensor, dict[NodeType, torch.Tensor]]: Whatever the
            encoder returns. This is annotated as a tensor for homogeneous
            input or a per-node-type mapping for heterogeneous input.

        Raises:
            ValueError: If ``data`` is ``HeteroData`` and
                ``output_node_types`` is ``None``.
        """
        # Homogeneous graphs never need node types, so handle them up front.
        if not isinstance(data, HeteroData):
            return self._encoder(data=data, device=device)

        # Heterogeneous encoders must be told which node types to emit.
        if output_node_types is None:
            raise ValueError(
                "Output node types must be specified in forward() pass for heterogeneous model"
            )

        return self._encoder(
            data=data,
            output_node_types=output_node_types,
            device=device,
        )

    def decode(
        self,
        query_embeddings: torch.Tensor,
        candidate_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Score query embeddings against candidate embeddings with the decoder.

        Args:
            query_embeddings (torch.Tensor): Query-side embeddings (shape is
                whatever the decoder expects -- TODO confirm).
            candidate_embeddings (torch.Tensor): Candidate-side embeddings
                (shape is whatever the decoder expects -- TODO confirm).

        Returns:
            torch.Tensor: The tensor returned by the decoder.
        """
        scores = self._decoder(
            query_embeddings=query_embeddings,
            candidate_embeddings=candidate_embeddings,
        )
        return scores