Source code for gigl.src.common.models.layers.decoder

from enum import Enum
from typing import Callable, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.models import MLP


class DecoderType(Enum):
    hadamard_MLP = "hadamard_MLP"
    inner_product = "inner_product"

    @classmethod
    def get_all_criteria(cls) -> List[str]:
        return [m.name for m in cls]
class LinkPredictionDecoder(nn.Module):
    def __init__(
        self,
        decoder_type: DecoderType = DecoderType.inner_product,
        decoder_channel_list: Optional[List[int]] = None,
        act: Union[str, Callable, None] = F.relu,
        act_first: bool = False,
        bias: Union[bool, List[bool]] = False,
        plain_last: bool = False,
        norm: Optional[Union[str, Callable]] = None,
    ):
        super(LinkPredictionDecoder, self).__init__()
        self.decoder_type = decoder_type
        self.decoder_channel_list = decoder_channel_list
        if self.decoder_type.value == "hadamard_MLP" and not isinstance(
            self.decoder_channel_list, List
        ):
            raise ValueError(
                f"The decoder channel list must be provided when using 'hadamard_MLP' decoder, however you provided {self.decoder_channel_list}"
            )
        if (
            isinstance(self.decoder_channel_list, List)
            and len(self.decoder_channel_list) <= 1
        ):
            raise ValueError(
                f"The decoder channel list must have length at least 2, however you provided a list of length {len(self.decoder_channel_list)}"
            )
        if (
            isinstance(self.decoder_channel_list, List)
            and self.decoder_channel_list[-1] != 1
        ):
            raise ValueError(
                f"The last element in decoder channel list must be equal to 1, however you provided {self.decoder_channel_list[-1]}"
            )
        if self.decoder_type.value == "hadamard_MLP":
            self.mlp_decoder = MLP(
                channel_list=self.decoder_channel_list,
                act=act,
                act_first=act_first,
                bias=bias,
                plain_last=plain_last,
                norm=norm,
            )
    def forward(self, query_embeddings, candidate_embeddings) -> torch.Tensor:
        if self.decoder_type.value == "inner_product":
            scores = torch.mm(query_embeddings, candidate_embeddings.T)
        elif self.decoder_type.value == "hadamard_MLP":
            hadamard_scores = query_embeddings.unsqueeze(dim=1) * candidate_embeddings
            scores = self.mlp_decoder(hadamard_scores).sum(dim=-1)
        return scores
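
A minimal usage sketch, not part of the module source above. The embedding dimension of 16, the channel list [16, 8, 1], and the batch sizes are illustrative assumptions; the only requirements implied by the code are that the first channel matches the embedding dimension and the last channel is 1.

# Usage sketch (illustrative values, not from the module source).
import torch

print(DecoderType.get_all_criteria())  # ["hadamard_MLP", "inner_product"]

decoder = LinkPredictionDecoder(
    decoder_type=DecoderType.hadamard_MLP,
    decoder_channel_list=[16, 8, 1],  # first entry = embedding dim, last entry must be 1
)
query_embeddings = torch.randn(4, 16)      # 4 query embeddings
candidate_embeddings = torch.randn(6, 16)  # 6 candidate embeddings
scores = decoder(query_embeddings, candidate_embeddings)  # shape: (4, 6)

With hadamard_MLP, the broadcasted element-wise product has shape (4, 6, 16); the MLP maps the last dimension to 1 and the trailing sum removes it, so both decoder types return a (num_queries, num_candidates) score matrix.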