Source code for gigl.src.common.models.layers.decoder
from enum import Enum
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.models import MLP
[docs]
class DecoderType(Enum):
    """Decoder architectures understood by ``LinkPredictionDecoder``.

    Every member's value is the same string as its name, so a member can be
    looked up from a config string with ``DecoderType("inner_product")``.
    """

    # Element-wise (Hadamard) product of query and candidate embeddings,
    # passed through an MLP.
    hadamard_MLP = "hadamard_MLP"
    # Plain dot product between query and candidate embeddings.
    inner_product = "inner_product"

    @classmethod
    def get_all_criteria(cls) -> list[str]:
        """Return the name of every member, in definition order."""
        member_names = [decoder.name for decoder in cls]
        return member_names
 
[docs]
class LinkPredictionDecoder(nn.Module):
    """Score query/candidate embedding pairs for link prediction.

    Two decoders are supported (see ``DecoderType``):

    * ``inner_product``: ``torch.mm(query_embeddings, candidate_embeddings.T)``.
    * ``hadamard_MLP``: every query embedding is multiplied element-wise with
      every candidate embedding (via broadcasting), the products are fed
      through a ``torch_geometric`` ``MLP``, and the MLP's output channels
      are summed.

    Args:
        decoder_type: Which decoder to use, as a ``DecoderType`` member or its
            string value (e.g. ``"inner_product"``). Defaults to
            ``DecoderType.inner_product``.
        decoder_channel_list: Channel sizes passed to the MLP of the
            ``hadamard_MLP`` decoder. Required for that decoder. Whenever a
            list is given, it must have at least two entries and end with 1.
        act: Forwarded unchanged to ``MLP``.
        act_first: Forwarded unchanged to ``MLP``.
        bias: Forwarded unchanged to ``MLP``.
        plain_last: Forwarded unchanged to ``MLP``.
        norm: Forwarded unchanged to ``MLP``.

    Raises:
        ValueError: If ``decoder_type`` is a string that is not a
            ``DecoderType`` value, or if ``decoder_channel_list`` fails the
            checks above.
    """

    def __init__(
        self,
        decoder_type: Union[DecoderType, str] = DecoderType.inner_product,
        decoder_channel_list: Optional[list[int]] = None,
        act: Union[str, Callable, None] = F.relu,
        act_first: bool = False,
        bias: Union[bool, list[bool]] = False,
        plain_last: bool = False,
        norm: Optional[Union[str, Callable]] = None,
    ):
        super().__init__()
        # Accept the string form (e.g. read from a config) as well as the member.
        if isinstance(decoder_type, str):
            decoder_type = DecoderType(decoder_type)
        self.decoder_type = decoder_type
        self.decoder_channel_list = decoder_channel_list

        is_hadamard = self.decoder_type.value == DecoderType.hadamard_MLP.value
        has_channel_list = isinstance(self.decoder_channel_list, list)

        if is_hadamard and not has_channel_list:
            raise ValueError(
                f"The decoder channel list must be provided when using 'hadamard_MLP' decoder, however you provided {self.decoder_channel_list}"
            )
        if has_channel_list and len(self.decoder_channel_list) <= 1:
            raise ValueError(
                f"The decoder channel list must have length at least 2, however you provided a list of length {len(self.decoder_channel_list)}"
            )
        # A single output channel lets forward() collapse the MLP output to one
        # score per (query, candidate) pair.
        if has_channel_list and self.decoder_channel_list[-1] != 1:
            raise ValueError(
                f"The last element in decoder channel list must be equal to 1, however you provided {self.decoder_channel_list[-1]}"
            )
        if is_hadamard:
            self.mlp_decoder = MLP(
                channel_list=self.decoder_channel_list,
                act=act,
                act_first=act_first,
                bias=bias,
                plain_last=plain_last,
                norm=norm,
            )

    def forward(
        self, query_embeddings: torch.Tensor, candidate_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """Compute a score for every (query, candidate) pair.

        Both decoders are written for 2-D inputs: ``query_embeddings`` as
        ``(num_queries, dim)`` and ``candidate_embeddings`` as
        ``(num_candidates, dim)``. ``torch.mm`` requires 2-D tensors. With
        such inputs the result is a ``(num_queries, num_candidates)`` tensor.

        Raises:
            ValueError: If ``self.decoder_type`` is not a supported decoder.
        """
        decoder_value = self.decoder_type.value
        if decoder_value == DecoderType.inner_product.value:
            return torch.mm(query_embeddings, candidate_embeddings.T)
        if decoder_value == DecoderType.hadamard_MLP.value:
            # (Q, 1, D) * (C, D) broadcasts to (Q, C, D): one Hadamard product
            # per (query, candidate) pair.
            # NOTE(review): the MLP therefore sees 3-D input; if `norm` is set,
            # confirm that the chosen norm layer handles 3-D input correctly.
            hadamard_scores = query_embeddings.unsqueeze(dim=1) * candidate_embeddings
            # The last channel is validated to be 1 in __init__, so summing the
            # last dim reduces (Q, C, 1) to (Q, C).
            return self.mlp_decoder(hadamard_scores).sum(dim=-1)
        # Previously this fell through and raised UnboundLocalError on `scores`.
        raise ValueError(
            f"Unsupported decoder type {self.decoder_type!r}; expected one of {DecoderType.get_all_criteria()}"
        )