Source code for gigl.src.common.models.pyg.link_prediction

from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch_geometric

from gigl.common.logger import Logger
from gigl.src.common.models.layers.decoder import LinkPredictionDecoder
from gigl.src.common.models.layers.task import NodeAnchorBasedLinkPredictionTasks
from gigl.src.common.types.graph_data import NodeType
from gigl.src.common.types.model import GraphBackend

# Module-level logger; LinkPredictionGNN.__init__ uses it to emit a deprecation
# warning every time the class is constructed.
logger: Logger = Logger()
class LinkPredictionGNN(nn.Module):
    """
    Link Prediction GNN model for both homogeneous and heterogeneous use cases.

    Deprecated: constructing this class logs a warning; use
    ``gigl.module.models.LinkPredictionGNN`` instead.

    Args:
        encoder (nn.Module): Either BasicGNN or Heterogeneous GNN for generating embeddings
        decoder (nn.Module): LinkPredictionDecoder for transforming embeddings into scores
        tasks (NodeAnchorBasedLinkPredictionTasks): Learning tasks for training
            (i.e. Retrieval, Margin, SSL, ...). May be None.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: LinkPredictionDecoder,
        tasks: Optional[NodeAnchorBasedLinkPredictionTasks] = None,
    ) -> None:
        super().__init__()
        # The path below must name THIS module so users can locate the deprecated class.
        logger.warning(
            "gigl.src.common.models.pyg.link_prediction.LinkPredictionGNN is deprecated and will be removed in a future release. "
            "Please use the `gigl.module.models.LinkPredictionGNN` class instead."
        )
        # NOTE: double-underscore attributes are name-mangled (e.g. to
        # `_LinkPredictionGNN__encoder`). nn.Module registers submodules under that
        # mangled name, so it also prefixes state_dict keys. Do not rename these
        # without migrating saved checkpoints.
        self.__encoder = encoder
        self.__decoder = decoder
        self.__tasks = tasks

    def forward(
        self,
        data: Union[
            torch_geometric.data.Data, torch_geometric.data.hetero_data.HeteroData
        ],
        output_node_types: List[NodeType],
        device: torch.device,
    ) -> Dict[NodeType, torch.Tensor]:
        """
        Encode a graph into embeddings keyed by node type.

        Args:
            data: A PyG ``Data`` (homogeneous) or ``HeteroData`` (heterogeneous) graph.
            output_node_types: Node types to produce embeddings for. Homogeneous data
                must supply exactly one node type.
            device: Device forwarded to the encoder.

        Returns:
            For heterogeneous data, whatever the encoder returns (expected to be a
            mapping of node type to embeddings). For homogeneous data, a single-entry
            dict mapping the one requested node type to the encoder output.

        Raises:
            NotImplementedError: If ``data`` is homogeneous and ``output_node_types``
                does not contain exactly one node type.
        """
        if isinstance(data, torch_geometric.data.hetero_data.HeteroData):
            return self.__encoder(
                data=data, output_node_types=output_node_types, device=device
            )
        # A homogeneous graph has a single node type. Reject both an empty list
        # (which would otherwise fail with an opaque IndexError below) and more
        # than one type.
        if len(output_node_types) != 1:
            raise NotImplementedError(
                f"Found {len(output_node_types)} output node types for homogeneous data, which must have one node type"
            )
        return {output_node_types[0]: self.__encoder(data=data, device=device)}

    def decode(
        self,
        query_embeddings: torch.Tensor,
        candidate_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Score query embeddings against candidate embeddings using the decoder.

        Args:
            query_embeddings: Embeddings of the query (anchor) nodes.
            candidate_embeddings: Embeddings of the candidate nodes.

        Returns:
            The decoder's output tensor of scores.
        """
        return self.__decoder(
            query_embeddings=query_embeddings,
            candidate_embeddings=candidate_embeddings,
        )

    @property
    def tasks(self) -> NodeAnchorBasedLinkPredictionTasks:
        """
        Learning tasks passed at construction time.

        Note: this is ``None`` when no tasks were provided. The annotation is kept
        non-Optional for compatibility with existing callers, hence the ignore.
        """
        return self.__tasks  # type: ignore

    @property
    def graph_backend(self) -> GraphBackend:
        """Graph backend this model is implemented against (always PyG)."""
        return GraphBackend.PYG