Source code for gigl.module.models

from typing import Optional, Union

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch_geometric.data import Data, HeteroData
from typing_extensions import Self

from gigl.src.common.types.graph_data import NodeType


class LinkPredictionGNN(nn.Module):
    """
    Link Prediction GNN model for both homogeneous and heterogeneous use cases.

    Args:
        encoder (nn.Module): Either BasicGNN or Heterogeneous GNN for generating embeddings.
        decoder (nn.Module): Decoder for transforming embeddings into scores. Recommended to use
            `gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder`.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
    ) -> None:
        super().__init__()
        self._encoder = encoder
        self._decoder = decoder

    def forward(
        self,
        data: Union[Data, HeteroData],
        device: torch.device,
        output_node_types: Optional[list[NodeType]] = None,
    ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
        """
        Runs the encoder over ``data`` to produce embeddings.

        Args:
            data (Union[Data, HeteroData]): The input graph. A ``HeteroData`` input is passed to
                the encoder together with ``output_node_types``; any other input is passed
                without it.
            device (torch.device): Forwarded to the encoder as the ``device`` keyword argument.
            output_node_types (Optional[list[NodeType]]): Node types to produce outputs for.
                Required when ``data`` is a ``HeteroData``; ignored otherwise.

        Returns:
            Union[torch.Tensor, dict[NodeType, torch.Tensor]]: Whatever the encoder returns.
                Per the annotation, presumably a tensor for homogeneous graphs and a dict keyed
                by node type for heterogeneous graphs (depends on the encoder).

        Raises:
            ValueError: If ``data`` is a ``HeteroData`` and ``output_node_types`` is None.
        """
        if isinstance(data, HeteroData):
            if output_node_types is None:
                raise ValueError(
                    "Output node types must be specified in forward() pass for heterogeneous model"
                )
            return self._encoder(
                data=data, output_node_types=output_node_types, device=device
            )
        else:
            return self._encoder(data=data, device=device)

    def decode(
        self,
        query_embeddings: torch.Tensor,
        candidate_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Scores query embeddings against candidate embeddings using the decoder.

        Args:
            query_embeddings (torch.Tensor): Forwarded to the decoder as ``query_embeddings``.
            candidate_embeddings (torch.Tensor): Forwarded to the decoder as
                ``candidate_embeddings``.

        Returns:
            torch.Tensor: The decoder's output (the scores).
        """
        return self._decoder(
            query_embeddings=query_embeddings,
            candidate_embeddings=candidate_embeddings,
        )

    @property
    def encoder(self) -> nn.Module:
        """The encoder module. After ``to_ddp()`` this is the DDP-wrapped encoder."""
        return self._encoder

    @property
    def decoder(self) -> nn.Module:
        """The decoder module. After ``to_ddp()`` this may be the DDP-wrapped decoder."""
        return self._decoder

    def to_ddp(
        self,
        device: Optional[torch.device],
        find_unused_encoder_parameters: bool = False,
    ) -> Self:
        """
        Wraps the encoder (and, if it has trainable parameters, the decoder) in
        DistributedDataParallel (DDP), in place.

        The encoder and decoder are wrapped individually rather than wrapping this whole model.
        DDP does *not* expect the forward method of the modules it wraps to be called directly.
        See how DistributedDataParallel.forward calls _pre_forward:
        https://github.com/pytorch/pytorch/blob/26807dcf277feb2d99ab88d7b6da526488baea93/torch/nn/parallel/distributed.py#L1657
        This model calls its encoder and decoder directly (via ``forward()`` / ``decode()``).
        Wrapping each component routes those calls through DDP's own forward.

        After calling this function it is safe to call ``model.encoder(data=..., device=...)``
        and ``model.decoder(...)`` directly.

        Components that are already DDP-wrapped are left as-is, so calling this method more
        than once never nests a module inside a second DDP wrapper.

        Args:
            device (Optional[torch.device]): The device to which the model should be moved.
                If None, will default to CPU.
            find_unused_encoder_parameters (bool): Passed to the encoder's DDP wrapper as
                ``find_unused_parameters``. Set this to True if the encoder has parameters that
                are not used in the forward pass.

        Returns:
            Self: This same instance, mutated in place (not a copy). Use ``unwrap_from_ddp()``
            to obtain an unwrapped model.
        """
        if device is None:
            device = torch.device("cpu")
        # DDP requires device_ids to be None for CPU modules.
        device_ids = [device] if device.type != "cpu" else None

        if isinstance(self._encoder, DistributedDataParallel):
            # Already wrapped by a previous call; re-wrapping would nest DDP inside DDP.
            ddp_encoder = self._encoder
        else:
            ddp_encoder = DistributedDataParallel(
                self._encoder.to(device),
                device_ids=device_ids,
                find_unused_parameters=find_unused_encoder_parameters,
            )

        ddp_decoder: nn.Module
        if isinstance(self._decoder, DistributedDataParallel):
            # Already wrapped by a previous call.
            ddp_decoder = self._decoder
        elif not any(p.requires_grad for p in self._decoder.parameters()):
            # If the decoder has no trainable parameters, we can just use it as is:
            # DDP refuses to wrap a module with no parameters that require gradients.
            ddp_decoder = self._decoder.to(device)
        else:
            ddp_decoder = DistributedDataParallel(
                self._decoder.to(device),
                device_ids=device_ids,
            )

        self._encoder = ddp_encoder
        self._decoder = ddp_decoder
        return self

    def unwrap_from_ddp(self) -> "LinkPredictionGNN":
        """
        Unwraps the model from DistributedDataParallel if it is wrapped.

        Each component is unwrapped independently; a component that is not DDP-wrapped is
        used as-is. The returned model shares the underlying encoder/decoder module objects
        with this instance (no parameters are copied), and this instance is not modified.

        Returns:
            LinkPredictionGNN: A new instance of LinkPredictionGNN with the original encoder
            and decoder.
        """
        if isinstance(self._encoder, DistributedDataParallel):
            encoder = self._encoder.module
        else:
            encoder = self._encoder
        if isinstance(self._decoder, DistributedDataParallel):
            decoder = self._decoder.module
        else:
            decoder = self._decoder
        return LinkPredictionGNN(encoder=encoder, decoder=decoder)