# Source code for gigl.module.models
from typing import Optional, Union
import torch
import torch.nn as nn
from torch_geometric.data import Data, HeteroData
from gigl.src.common.models.pyg.link_prediction import LinkPredictionDecoder
from gigl.src.common.types.graph_data import NodeType
# [docs] -- link marker left over from the rendered documentation page; not code.
class LinkPredictionGNN(nn.Module):
    """
    Link Prediction GNN model for both homogeneous and heterogeneous use cases.

    The model pairs an encoder, which is called from ``forward()``, with a
    decoder, which is called from ``decode()``.

    Args:
        encoder (nn.Module): Either BasicGNN or Heterogeneous GNN for generating embeddings
        decoder (nn.Module): LinkPredictionDecoder for transforming embeddings into scores
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: LinkPredictionDecoder,
    ) -> None:
        super().__init__()
        self._encoder = encoder
        self._decoder = decoder

    def forward(
        self,
        data: Union[Data, HeteroData],
        device: torch.device,
        output_node_types: Optional[list[NodeType]] = None,
    ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
        """
        Run the encoder on ``data`` and return whatever the encoder returns.

        Args:
            data (Union[Data, HeteroData]): Input graph. A ``HeteroData`` instance
                selects the heterogeneous path; anything else is treated as
                homogeneous.
            device (torch.device): Forwarded unchanged to the encoder.
            output_node_types (Optional[list[NodeType]]): Forwarded to the encoder
                on the heterogeneous path, where it is required. It is not passed
                to the encoder on the homogeneous path.

        Returns:
            Union[torch.Tensor, dict[NodeType, torch.Tensor]]: The encoder output.
            Per the annotation this is presumably a single tensor for homogeneous
            input and a per-node-type dict for heterogeneous input; the exact
            structure is determined by the encoder.

        Raises:
            ValueError: If ``data`` is a ``HeteroData`` and ``output_node_types``
                is None.
        """
        # Homogeneous path: the encoder is called without output_node_types.
        if not isinstance(data, HeteroData):
            return self._encoder(data=data, device=device)

        # Heterogeneous path: fail before invoking the encoder when the caller
        # did not say which node types to produce output for.
        if output_node_types is None:
            raise ValueError(
                "Output node types must be specified in forward() pass for heterogeneous model"
            )
        return self._encoder(
            data=data, output_node_types=output_node_types, device=device
        )

    def decode(
        self,
        query_embeddings: torch.Tensor,
        candidate_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Pass query and candidate embeddings to the decoder and return its output.

        Args:
            query_embeddings (torch.Tensor): Forwarded to the decoder as
                ``query_embeddings``.
            candidate_embeddings (torch.Tensor): Forwarded to the decoder as
                ``candidate_embeddings``.

        Returns:
            torch.Tensor: The decoder output (scores, per the class description).
        """
        return self._decoder(
            query_embeddings=query_embeddings,
            candidate_embeddings=candidate_embeddings,
        )