from typing import Optional
import torch
import torch_geometric.data
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import Linear
from gigl.src.common.models.layers.normalization import l2_normalize_embeddings
from gigl.src.common.models.pyg.nn.conv.hgt_conv import HGTConv
from gigl.src.common.models.pyg.nn.conv.simplehgn_conv import SimpleHGNConv
from gigl.src.common.models.pyg.nn.models.feature_embedding import FeatureEmbeddingLayer
from gigl.src.common.models.utils.torch import to_hetero_feat
from gigl.src.common.types.graph_data import EdgeType, NodeType
# HGT serves as a soft template for how future heterogeneous GNN models should implement __init__ and forward.
class HGT(nn.Module):
    """
    Heterogeneous Graph Transformer model. Paper: https://arxiv.org/pdf/2003.01332.pdf
    This implementation is based on the example of:
    https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hgt_dblp.py
    Args:
        node_type_to_feat_dim_map (dict[NodeType, int]): Dictionary mapping node types to their input dimensions.
        edge_type_to_feat_dim_map (dict[EdgeType, int]): Dictionary mapping edge types to their feature dimensions.
        hid_dim (int): Hidden dimension size.
        out_dim (int, optional): Output dimension size. Defaults to 128.
        num_layers (int, optional): Number of layers. Defaults to 2.
        num_heads (int, optional): Number of attention heads. Defaults to 2.
        should_l2_normalize_embedding_layer_output (bool, optional): Whether to L2-normalize the output embeddings. Defaults to False.
        feature_embedding_layers (Optional[dict[NodeType, FeatureEmbeddingLayer]], optional): Optional per-node-type feature embedding layers applied to raw node features before the model. Defaults to None.
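    Example:
        A minimal usage sketch (illustrative only). For brevity it uses plain strings and tuples
        as node / edge types, which PyG accepts at runtime; in gigl pipelines these would
        typically be NodeType / EdgeType values, and the toy graph, type names, and feature
        dimensions below are hypothetical::

            import torch
            from torch_geometric.data import HeteroData

            data = HeteroData()
            data["user"].x = torch.randn(4, 8)
            data["item"].x = torch.randn(6, 12)
            data["user", "clicks", "item"].edge_index = torch.tensor([[0, 1, 2], [0, 3, 5]])
            # Add reverse edges so every node type receives messages in each conv layer.
            data["item", "rev_clicks", "user"].edge_index = torch.tensor([[0, 3, 5], [0, 1, 2]])

            model = HGT(
                node_type_to_feat_dim_map={"user": 8, "item": 12},
                edge_type_to_feat_dim_map={
                    ("user", "clicks", "item"): 0,
                    ("item", "rev_clicks", "user"): 0,
                },
                hid_dim=16,
                out_dim=8,
            )
            embeddings = model(
                data=data,
                output_node_types=["user", "item"],
                device=torch.device("cpu"),
            )  # {"user": tensor of shape [4, 8], "item": tensor of shape [6, 8]}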
    """
    def __init__(
        self,
        node_type_to_feat_dim_map: dict[NodeType, int],
        edge_type_to_feat_dim_map: dict[EdgeType, int],
        hid_dim: int,
        out_dim: int = 128,
        num_layers: int = 2,
        num_heads: int = 2,
        should_l2_normalize_embedding_layer_output: bool = False,
        feature_embedding_layers: Optional[
            dict[NodeType, FeatureEmbeddingLayer]
        ] = None,
        **kwargs,
    ):
        super().__init__()
        self._node_types = list(node_type_to_feat_dim_map.keys())
        self._edge_types = list(edge_type_to_feat_dim_map.keys())
        self.lin_dict = torch.nn.ModuleDict()
        for node_type, in_dim in node_type_to_feat_dim_map.items():
            self.lin_dict[node_type] = Linear(in_channels=in_dim, out_channels=hid_dim)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(
                in_channels=hid_dim,
                out_channels=hid_dim,
                metadata=(self._node_types, self._edge_types),
                heads=num_heads,
            )
            self.convs.append(conv)
        self.lin = Linear(in_channels=hid_dim, out_channels=out_dim)
        self.should_l2_normalize_embedding_layer_output = (
            should_l2_normalize_embedding_layer_output
        )
        self.feature_embedding_layers = feature_embedding_layers
    def forward(
        self,
        data: torch_geometric.data.hetero_data.HeteroData,
        output_node_types: list[NodeType],
        device: torch.device,
    ) -> dict[NodeType, torch.Tensor]:
        """
        Runs the forward pass of the module
        Args:
            data (torch_geometric.data.hetero_data.HeteroData): Input HeteroData object.
            output_node_types (list[NodeType]): List of node types for which to return the output embeddings.
            device (torch.device): Device on which to place empty output tensors for requested node types absent from the input data.
        Returns:
            dict[NodeType, torch.Tensor]: Dictionary with node types as keys and output tensors as values.
        """
        node_type_to_features_dict = data.x_dict
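        # If per-node-type feature embedding layers were provided, apply them to the raw features first;
        # node types without a configured layer pass through unchanged.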
        if self.feature_embedding_layers:
            node_type_to_features_dict = {
                node_type: self.feature_embedding_layers[node_type](x)
                if node_type in self.feature_embedding_layers
                else x
                for node_type, x in node_type_to_features_dict.items()
            }
        # When we initialize a HGTConv layer, we provide some edge types, which it uses to create an offset mapping for
        # each edge type. When we forward some data.edge_index_dict through this layer, we require that the edge types
        # there have the same order as the edge types in the constructor, otherwise the offsets will be off. For large graphs,
        # this will lead to indexing errors during segmented matrix multiplication. For smaller graphs, segmented matrix
        # multiplication is not used (based on some heuristic in PyG) and we don't observe this error. However, the indices
        # are still wrong and likely lead to incorrect forward passes, hurting the model performance.
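        # Concretely: if the convs were initialized with edge types [("user", "clicks", "item"),
        # ("item", "rev_clicks", "user")] (hypothetical types), the per-type offsets assume exactly that
        # order, so we validate that the edge-type sets match and then re-key edge_index_dict in the
        # constructor's order before calling the convs.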
        if sorted(self._edge_types) != sorted(data.edge_index_dict.keys()):
            raise ValueError(
                f"Found mismatching edge types between HGTConv initialized edge types {self._edge_types} and HeteroData edge types {sorted(data.edge_index_dict)}. These must be the same."
            )
        edge_index_dict = {
            edge_type: data.edge_index_dict[edge_type] for edge_type in self._edge_types
        }
        node_type_to_features_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in node_type_to_features_dict.items()
        }
        for conv in self.convs:
            node_type_to_features_dict = conv(
                node_type_to_features_dict, edge_index_dict
            )
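        # Project each requested node type to the output dimension; node types absent from the batch
        # get an empty tensor on the target device.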
        node_typed_embeddings: dict[NodeType, torch.Tensor] = {}
        for node_type in output_node_types:
            node_typed_embeddings[node_type] = (
                self.lin(node_type_to_features_dict[node_type])
                if node_type in node_type_to_features_dict
                else torch.FloatTensor([]).to(device=device)
            )
        if self.should_l2_normalize_embedding_layer_output:
            node_typed_embeddings = l2_normalize_embeddings(  # type: ignore
                node_typed_embeddings=node_typed_embeddings
            )
        return node_typed_embeddings 
 
class SimpleHGN(nn.Module):
    def __init__(
        self,
        node_type_to_feat_dim_map: dict[NodeType, int],
        edge_type_to_feat_dim_map: dict[EdgeType, int],
        node_hid_dim: int,
        edge_hid_dim: int,
        edge_type_dim: int,
        node_out_dim: int = 128,
        num_layers: int = 2,
        num_heads: int = 2,
        should_use_node_residual: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        activation=F.elu,
        should_l2_normalize_embedding_layer_output: bool = False,
        **kwargs,
    ):
        """
        SimpleHGN model from the paper: https://arxiv.org/pdf/2112.14936
        Args:
            node_type_to_feat_dim_map (dict[NodeType, int]): Dictionary mapping node types to their input dimensions.
            edge_type_to_feat_dim_map (dict[EdgeType, int]): Dictionary mapping edge types to their feature dimensions.
            node_hid_dim (int): Hidden dimension size for node features.
            edge_hid_dim (int): Hidden dimension size for edge features.
            edge_type_dim (int): Hidden dimension size for edge types.
            node_out_dim (int): Output dimension size for node features. Defaults to 128.
            num_layers (int): Number of layers. Defaults to 2.
            num_heads (int): Number of attention heads. Defaults to 2.
            should_use_node_residual (bool): Whether to use node residual. Defaults to True.
            negative_slope (float): Negative slope used in the LeakyReLU. Defaults to 0.2.
            dropout (float): Dropout rate. Defaults to 0.0.
            activation: Activation function applied between conv layers. Defaults to `F.elu`.
            should_l2_normalize_embedding_layer_output (bool): Whether to L2-normalize the output embeddings. Defaults to False.
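        Example:
            A minimal construction sketch (illustrative only); type names and dimensions are
            hypothetical, and edge feature dims of 0 mean no edge-feature projections are built::

                model = SimpleHGN(
                    node_type_to_feat_dim_map={"user": 8, "item": 12},
                    edge_type_to_feat_dim_map={
                        ("user", "clicks", "item"): 0,
                        ("item", "rev_clicks", "user"): 0,
                    },
                    node_hid_dim=16,
                    edge_hid_dim=8,
                    edge_type_dim=8,
                    node_out_dim=8,
                )
                # `data` is a torch_geometric HeteroData with `x` set for every node type and
                # `edge_index` for every edge type (see the HGT docstring above for a toy example).
                embeddings = model(
                    data=data,
                    output_node_types=["user"],
                    device=torch.device("cpu"),
                )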
        """
        super().__init__()
        self.num_layers = num_layers
        self.should_l2_normalize_embedding_layer_output = (
            should_l2_normalize_embedding_layer_output
        )
        # Used to project all node types to a common dimension (node_hid_dim).
        self.node_type_lin_dict = torch.nn.ModuleDict()
        for node_type, in_dim in node_type_to_feat_dim_map.items():
            self.node_type_lin_dict[str(node_type)] = nn.Linear(
                in_features=in_dim, out_features=node_hid_dim
            )
        # Whether any edge type carries features; edge types with nonzero feature dims get a
        # projection layer that maps their features to a common dimension (edge_hid_dim).
        self.should_have_edge_features: bool = any(edge_type_to_feat_dim_map.values())
        self.edge_type_lin_dict = torch.nn.ModuleDict()
        for edge_type, in_dim in edge_type_to_feat_dim_map.items():
            if in_dim:
                self.edge_type_lin_dict[str(edge_type)] = nn.Linear(
                    in_features=in_dim, out_features=edge_hid_dim
                )
        self.convs = torch.nn.ModuleList()
        for layer_id in range(num_layers):
            conv = SimpleHGNConv(
                in_channels=node_hid_dim if layer_id == 0 else node_hid_dim * num_heads,
                edge_in_channels=(
                    edge_hid_dim if self.should_have_edge_features else None
                ),
                edge_type_dim=edge_type_dim,
                out_channels=node_hid_dim,
                num_heads=num_heads,
                num_edge_types=len(edge_type_to_feat_dim_map),
                should_use_node_residual=should_use_node_residual,
                negative_slope=negative_slope,
                dropout=dropout,
            )
            self.convs.append(conv)
        self.lin = nn.Linear(
            in_features=node_hid_dim * num_heads, out_features=node_out_dim
        )
        self.activation = activation
    def forward(
        self,
        data: torch_geometric.data.hetero_data.HeteroData,
        output_node_types: list[NodeType],
        device: torch.device,
    ) -> dict[NodeType, torch.Tensor]:
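        """
        Runs the forward pass of the module
        Args:
            data (torch_geometric.data.hetero_data.HeteroData): Input HeteroData object.
            output_node_types (list[NodeType]): List of node types for which to return the output embeddings.
            device (torch.device): Accepted for interface parity with HGT.forward; not used in this forward pass.
        Returns:
            dict[NodeType, torch.Tensor]: Dictionary with node types as keys and output tensors as values.
        """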
        # Align dimensions across all node-types and all edge-types, resp.
        x_dict = {
            node_type: self.node_type_lin_dict[node_type](x)
            for node_type, x in data.x_dict.items()
        }
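        # Rebuild a HeteroData holding the projected node features plus, per edge type, the edge_index
        # and (when present) projected edge features.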
        init_dict = {
            edge_type: {
                "edge_index": data.edge_index_dict[edge_type],
            }
            for edge_type in data.edge_index_dict.keys()
        }
        for edge_type in data.edge_types:
            maybe_edge_attr = getattr(data[edge_type], "edge_attr", None)
            if isinstance(maybe_edge_attr, torch.Tensor):
                init_dict[edge_type].update(
                    {
                        "edge_attr": self.edge_type_lin_dict[
                            f"{edge_type[0]}-{edge_type[1]}-{edge_type[2]}"
                        ](maybe_edge_attr)
                    }
                )
        init_dict.update({node_type: {"x": x} for node_type, x in x_dict.items()})
        # Convert hetero to homo graph, so we can pass around homo graph info to conv forwards.
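        # to_homogeneous() concatenates the per-type x (and edge_attr, if present) tensors and records
        # integer node_type / edge_type index vectors, which the SimpleHGNConv layers consume below.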
        projected_hetero_data = torch_geometric.data.hetero_data.HeteroData(init_dict)
        projected_homo_data = projected_hetero_data.to_homogeneous()
        h = projected_homo_data.x
        for layer_id, conv in enumerate(self.convs):
            h = conv(
                edge_index=projected_homo_data.edge_index,
                node_feat=h,
                edge_feat=(
                    projected_homo_data.edge_attr
                    if self.should_have_edge_features
                    else None
                ),
                edge_type=projected_homo_data.edge_type,
            )
            if layer_id != self.num_layers - 1:
                h = self.activation(h)
        # Project to node output dim
        embeddings = self.lin(h)
        node_typed_embeddings = to_hetero_feat(
            h=embeddings,
            type_indices=projected_homo_data.node_type,
            types=projected_hetero_data.node_types,
        )
        for node_type in output_node_types:
            if node_type not in node_typed_embeddings:
                raise ValueError(
                    f"Requested node type {node_type} does not exist in output tensor."
                )
        if self.should_l2_normalize_embedding_layer_output:
            node_typed_embeddings = l2_normalize_embeddings(  # type: ignore
                node_typed_embeddings=node_typed_embeddings
            )
        return node_typed_embeddings