from typing import Any, Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data
from torch_geometric.nn import (
    GATConv,
    GATv2Conv,
    GCNConv,
    GINConv,
    SAGEConv,
    TransformerConv,
)
from torch_geometric.nn.models import MLP
from gigl.common.logger import Logger
from gigl.src.common.constants.training import DEFAULT_NUM_GNN_HOPS
from gigl.src.common.models.layers.normalization import l2_normalize_embeddings
from gigl.src.common.models.pyg import utils as pyg_utils
from gigl.src.common.models.pyg.nn.conv.edge_attr_gat_conv import EdgeAttrGATConv
from gigl.src.common.models.pyg.nn.conv.gin_conv import GINEConv
from gigl.src.common.models.pyg.nn.models.feature_embedding import FeatureEmbeddingLayer
from gigl.src.common.models.pyg.nn.models.feature_interaction import FeatureInteraction
from gigl.src.common.models.pyg.nn.models.jumping_knowledge import JumpingKnowledge
from gigl.src.common.types.model import GnnModel, GraphBackend

logger = Logger()

class BasicHomogeneousGNN(nn.Module, GnnModel):
    def __init__(
        self,
        in_dim: int,
        hid_dim: int,
        out_dim: int,
        conv_kwargs: dict[str, Any] = {},
        edge_dim: Optional[int] = None,
        num_layers: int = DEFAULT_NUM_GNN_HOPS,
        activation: Callable = F.relu,
        activation_before_norm: bool = False,  # apply activation function before normalization
        activation_after_last_conv: bool = False,  # apply activation after the last conv layer
        dropout: float = 0.0,  # dropout will auto set to 0.0 when model.eval()
        batchnorm: bool = False,  # batch norm
        linear_layer: bool = False,
        return_emb: bool = False,
        should_l2_normalize_embedding_layer_output: bool = False,
        jk_mode: Optional[str] = None,
        jk_lstm_dim: Optional[int] = None,
        feature_interaction_layer: Optional[FeatureInteraction] = None,
        feature_embedding_layer: Optional[FeatureEmbeddingLayer] = None,
        **kwargs,
    ):
        super(BasicHomogeneousGNN, self).__init__()
        self.activation = activation
        self.activation_before_norm = activation_before_norm
        self.activation_after_last_conv = activation_after_last_conv
        self.dropout = nn.Dropout(p=dropout)
        self.batchnorm = batchnorm
        self.num_layers = num_layers
        # Feature embedding layer to pass selected features through an embedding layer
        self.feature_embedding_layer = feature_embedding_layer
        # Feature interaction layers
        self.feats_interaction = feature_interaction_layer
        self.conv_layers: nn.ModuleList = self.init_conv_layers(  # type: ignore
            in_dim=in_dim,
            out_dim=hid_dim if linear_layer or jk_mode else out_dim,
            edge_dim=edge_dim,
            hid_dim=hid_dim,
            num_layers=num_layers,
            **conv_kwargs,
        ) 
        if batchnorm:
            num_heads = int(conv_kwargs.get("heads", 1))
            num_batchnorm_layers = num_layers if jk_mode else num_layers - 1
            self.batchnorm_layers = nn.ModuleList(
                [
                    nn.BatchNorm1d(hid_dim * num_heads)
                    for i in range(num_batchnorm_layers)
                ]
            )
        self.should_l2_normalize_embedding_layer_output = (
            should_l2_normalize_embedding_layer_output
        ) 
        if jk_mode:
            self.jk_layer = JumpingKnowledge(
                mode=jk_mode,
                hid_dim=hid_dim,
                out_dim=out_dim if not linear_layer else hid_dim,
                num_layers=num_layers,
                lstm_dim=jk_lstm_dim,
            )
        else:
            self.jk_layer = None  # type: ignore
        self.return_emb = return_emb
        self.linear_layer = linear_layer
        if linear_layer:
            self.linear = nn.Linear(hid_dim, out_dim)
    def forward(
        self,
        data: torch_geometric.data.Data,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        # pass selected features through an embedding layer
        if self.feature_embedding_layer:
            x = self.feature_embedding_layer(x)
        # node feature interaction before graph convolution
        if self.feats_interaction:
            x = self.feats_interaction(x)
        xs: list[torch.Tensor] = []
        for i, conv_layer in enumerate(self.conv_layers):
            if self.supports_edge_attr:
                x = conv_layer(x=x, edge_index=edge_index, edge_attr=edge_attr)
            else:
                x = conv_layer(x=x, edge_index=edge_index)
            # exclude batch norm, activation, dropout after last layer
            if (
                i == self.num_layers - 1
                and not self.jk_layer
                and not self.activation_after_last_conv
            ):
                break
            if self.activation_before_norm:
                x = self.activation(x)
            if self.batchnorm:
                x = self.batchnorm_layers[i](x)
            if not self.activation_before_norm:
                x = self.activation(x)
            x = self.dropout(x)
            if self.jk_layer:
                xs.append(x)
        if self.jk_layer:
            x = self.jk_layer(xs)
        if self.should_l2_normalize_embedding_layer_output:
            x = l2_normalize_embeddings(node_typed_embeddings=x)
        if self.return_emb:
            return x
        if self.linear_layer:
            x = self.linear(x)
        return x 
    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        """Builds and returns one conv layer per GNN hop; must be implemented by each concrete subclass."""
        raise NotImplementedError
    @property
    def graph_backend(self) -> GraphBackend:
        return GraphBackend.PYG 
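
# Hedged sketch (not part of the original module): BasicHomogeneousGNN cannot be used
# directly because init_conv_layers raises NotImplementedError, so concrete models
# subclass it and return one conv layer per hop. The minimal GCN-style subclass below is
# purely illustrative; it assumes an integer in_dim (GCNConv does not accept tuple input
# sizes) and ignores edge_dim and any extra conv kwargs.
class _ExampleGCN(BasicHomogeneousGNN):
    # GCNConv consumes neither edge weights nor edge attributes in this sketch.
    supports_edge_weight = False
    supports_edge_attr = False

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        # First layer maps in_dim -> hid_dim; the last layer maps hid_dim -> out_dim.
        return nn.ModuleList(
            [
                GCNConv(
                    in_channels=in_dim if i == 0 else hid_dim,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                )
                for i in range(num_layers)
            ]
        )
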
 
class GraphSAGE(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = False

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + ["aggr", "normalize", "root_weight", "project", "bias"],
        )
        logger.info(
            f"Discarded kwargs for {SAGEConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                SAGEConv(
                    in_channels=in_dim if i == 0 else hid_dim,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers 
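
# Hedged usage sketch (not part of the original module): builds a tiny random graph and
# runs it through GraphSAGE. The feature/hidden/output sizes and the toy graph are
# illustrative assumptions, not values used elsewhere in GiGL.
def _example_graphsage_usage() -> torch.Tensor:
    # 4 nodes with 8-dim features and 4 directed edges forming a cycle.
    data = torch_geometric.data.Data(
        x=torch.randn(4, 8),
        edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),
    )
    # Two SAGEConv hops: 8 -> 16 (hidden) -> 8 (output). GraphSAGE ignores edge_attr.
    model = GraphSAGE(in_dim=8, hid_dim=16, out_dim=8, num_layers=2)
    model.eval()  # disables dropout
    with torch.no_grad():
        out = model(data)  # shape: [4, 8]
    return out
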
 
class GIN(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = False

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        eps: float = kwargs.pop("eps", 0.0)
        train_eps: bool = kwargs.pop("train_eps", False)
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs, keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
        )
        logger.info(
            f"Discarded kwargs for {GINConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                GINConv(
                    nn=MLP(
                        [
                            in_dim if i == 0 else hid_dim,
                            hid_dim if i < num_layers - 1 else out_dim,
                            hid_dim if i < num_layers - 1 else out_dim,
                        ],
                        act=self.activation,
                        act_first=self.activation_before_norm,
                        # Note: PyG has its own BatchNorm class so this BatchNorm won't be converted to torch.nn.SyncBatchNorm
                        norm="batch_norm" if self.batchnorm else None,
                    ),
                    eps=eps,
                    train_eps=train_eps,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers 
 
class GINE(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        eps: float = kwargs.pop("eps", 0.0)
        train_eps: bool = kwargs.pop("train_eps", False)
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs, keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
        )
        logger.info(
            f"Discarded kwargs for {GINEConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                GINEConv(
                    nn=MLP(
                        [
                            in_dim if i == 0 else hid_dim,
                            hid_dim if i < num_layers - 1 else out_dim,
                            hid_dim if i < num_layers - 1 else out_dim,
                        ],
                        act=self.activation,
                        act_first=self.activation_before_norm,
                        # Note: PyG has its own BatchNorm class so this BatchNorm won't be converted to torch.nn.SyncBatchNorm
                        norm="batch_norm" if self.batchnorm else None,
                    ),
                    eps=eps,
                    train_eps=train_eps,
                    edge_dim=edge_dim,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers 
 
class GAT(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        num_heads = int(kwargs.pop("heads", 1))
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + [
                "concat",
                "negative_slope",
                "dropout",
                "add_self_loops",
                "fill_value",
                "bias",
            ],
        )
        logger.info(
            f"Discarded kwargs for {GATConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                GATConv(
                    in_channels=in_dim if i == 0 else hid_dim * num_heads,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    edge_dim=edge_dim,
                    heads=num_heads if i < num_layers - 1 else 1,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers 
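
# Hedged sketch (not part of the original module): illustrates how the GAT wrapper wires
# multi-head attention -- hidden layers concatenate heads (hid_dim * heads features),
# while the last layer uses a single head of out_dim. The sizes, edge features, and the
# toy graph below are illustrative assumptions.
def _example_gat_heads() -> torch.Tensor:
    data = torch_geometric.data.Data(
        x=torch.randn(5, 12),
        edge_index=torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]),
        edge_attr=torch.randn(5, 3),  # one 3-dim attribute per edge (edge_dim=3)
    )
    # heads=4: layer 0 maps 12 -> 16 * 4 = 64, layer 1 maps 64 -> 8 with a single head.
    model = GAT(
        in_dim=12,
        hid_dim=16,
        out_dim=8,
        edge_dim=3,
        num_layers=2,
        conv_kwargs={"heads": 4},
    )
    model.eval()
    with torch.no_grad():
        out = model(data)  # shape: [5, 8]
    return out
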
 
class GATv2(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        num_heads = int(kwargs.pop("heads", 1))
        fill_value = kwargs.pop("fill_value", "mean")
        share_weights = kwargs.pop("share_weights", False)
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + ["concat", "negative_slope", "dropout", "add_self_loops", "bias"],
        )
        logger.info(
            f"Discarded kwargs for {GATv2Conv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                GATv2Conv(
                    in_channels=in_dim if i == 0 else hid_dim * num_heads,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    edge_dim=edge_dim,
                    heads=num_heads if i < num_layers - 1 else 1,
                    fill_value=fill_value,
                    share_weights=share_weights,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers 
 
class EdgeAttrGAT(BasicHomogeneousGNN):
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv_layers(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        edge_dim: Optional[int],
        hid_dim: int,
        num_layers: int,
        **kwargs,
    ) -> nn.ModuleList:
        num_heads = int(kwargs.pop("heads", 1))
        share_edge_att_message_weight = kwargs.pop(
            "share_edge_att_message_weight", True
        )
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + [
                "concat",
                "negative_slope",
                "dropout",
                "add_self_loops",
                "fill_value",
                "bias",
            ],
        )
        logger.info(
            f"Discarded kwargs for {EdgeAttrGATConv.__name__}: {discarded_kwargs.keys()}"
        )
        conv_layers = nn.ModuleList(
            [
                EdgeAttrGATConv(
                    in_channels=in_dim if i == 0 else hid_dim * num_heads,
                    out_channels=hid_dim if i < num_layers - 1 else out_dim,
                    edge_dim=edge_dim,
                    heads=num_heads if i < num_layers - 1 else 1,
                    share_edge_att_message_weight=share_edge_att_message_weight,
                    **remaining_kwargs,
                )
                for i in range(num_layers)
            ]
        )
        return conv_layers 
 
class TwoLayerGCN(torch.nn.Module, GnnModel):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hid_dim: int = 16,
        is_training: bool = True,
        should_l2_normalize_output: bool = False,
        **kwargs,
    ):
        """
        Simple 2 layer GCN Implementation using PyG constructs
        Args:
            in_feats (int): number input features
            out_dim (int): num output classes
            h_feats (int, optional): num hidden features. Defaults to 16.
            **kwargs (:class:`torch_geometric.nn.conv.MessagePassing`):
                Additional arguments for all GCNConv layers
        """
        super().__init__()
        self.is_training = is_training
        self.should_normalize = should_l2_normalize_output
        remaining_kwargs, discarded_kwargs = pyg_utils.filter_dict(
            input_dict=kwargs,
            keys_to_keep=pyg_utils.MESSAGE_PASSING_BASE_CLS_ARGS
            + [
                "improved",
                "cached",
                "add_self_loops",
                "normalize",
                "bias",
            ],
        )
        logger.info(
            f"Discarded kwargs for {GCNConv.__name__}: {discarded_kwargs.keys()}"
        )
        self.conv1 = GCNConv(
            in_channels=in_dim, out_channels=hid_dim, **remaining_kwargs
        )
        self.conv2 = GCNConv(
            in_channels=hid_dim, out_channels=out_dim, **remaining_kwargs
        )

    def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.is_training)
        x = self.conv2(x, edge_index)
        if self.should_normalize:
            x = F.normalize(x, p=2, dim=1)
        return x 
    @property
    def graph_backend(self) -> GraphBackend:
        return GraphBackend.PYG
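
# Hedged usage sketch (not part of the original module): runs the simple two-layer GCN on
# a toy graph. The sizes and the example graph are illustrative assumptions.
def _example_two_layer_gcn_usage() -> torch.Tensor:
    data = torch_geometric.data.Data(
        x=torch.randn(6, 10),
        edge_index=torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]]),
    )
    # 10 input features, 16 hidden units (the default), 3 output classes;
    # is_training=False disables dropout in the forward pass.
    model = TwoLayerGCN(in_dim=10, out_dim=3, hid_dim=16, is_training=False)
    with torch.no_grad():
        logits = model(data)  # shape: [6, 3]
    return logits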