from typing import Optional, Union
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch_geometric.data import Data, HeteroData
from torch_geometric.nn.conv import LGConv
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from typing_extensions import Self
from gigl.src.common.types.graph_data import NodeType
from gigl.types.graph import to_heterogeneous_node
class LinkPredictionGNN(nn.Module):
    """
    Link Prediction GNN model for both homogeneous and heterogeneous use cases
    Args:
        encoder (nn.Module): Either BasicGNN or Heterogeneous GNN for generating embeddings
        decoder (nn.Module): Decoder for transforming embeddings into scores.
            Recommended to use `gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder`
    """
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
    ) -> None:
        super().__init__()
        self._encoder = encoder
        self._decoder = decoder
    def forward(
        self,
        data: Union[Data, HeteroData],
        device: torch.device,
        output_node_types: Optional[list[NodeType]] = None,
    ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
        if isinstance(data, HeteroData):
            if output_node_types is None:
                raise ValueError(
                    "Output node types must be specified in the forward() pass for a heterogeneous model"
                )
            return self._encoder(
                data=data, output_node_types=output_node_types, device=device
            )
        else:
            return self._encoder(data=data, device=device) 
    def decode(
        self,
        query_embeddings: torch.Tensor,
        candidate_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        return self._decoder(
            query_embeddings=query_embeddings,
            candidate_embeddings=candidate_embeddings,
        ) 
    @property
    def encoder(self) -> nn.Module:
        return self._encoder 
    @property
    def decoder(self) -> nn.Module:
        return self._decoder 
    def to_ddp(
        self,
        device: Optional[torch.device],
        find_unused_encoder_parameters: bool = False,
    ) -> Self:
        """
        Converts the model to DistributedDataParallel (DDP) mode.
        We do this because DDP does *not* expect the forward method of the modules it wraps to be called directly.
        See how DistributedDataParallel.forward calls _pre_forward:
        https://github.com/pytorch/pytorch/blob/26807dcf277feb2d99ab88d7b6da526488baea93/torch/nn/parallel/distributed.py#L1657
        If we do not do this, then calling forward() on the individual modules may not work correctly.
        Calling this function makes it safe to do: `LinkPredictionGNN.decoder(data, device)`
        Args:
            device (Optional[torch.device]): The device to which the model should be moved.
                If None, will default to CPU.
            find_unused_encoder_parameters (bool): Whether to find unused parameters in the model.
                This should be set to True if the model has parameters that are not used in the forward pass.
        Returns:
            LinkPredictionGNN: A new instance of LinkPredictionGNN for use with DDP.
        """
        if device is None:
            device = torch.device("cpu")
        ddp_encoder = DistributedDataParallel(
            self._encoder.to(device),
            device_ids=[device] if device.type != "cpu" else None,
            find_unused_parameters=find_unused_encoder_parameters,
        )
        # Do this "backwards" so the we can define "ddp_decoder" as a nn.Module first...
        if not any(p.requires_grad for p in self._decoder.parameters()):
            # If the decoder has no trainable parameters, we can just use it as is
            ddp_decoder = self._decoder.to(device)
        else:
            # Only wrap the decoder in DDP if it has parameters that require gradients
            # Otherwise DDP will complain about no parameters to train.
            ddp_decoder = DistributedDataParallel(
                self._decoder.to(device),
                device_ids=[device] if device.type != "cpu" else None,
            )
        self._encoder = ddp_encoder
        self._decoder = ddp_decoder
        return self 
    def unwrap_from_ddp(self) -> "LinkPredictionGNN":
        """
        Unwraps the model from DistributedDataParallel if it is wrapped.
        Returns:
            LinkPredictionGNN: A new instance of LinkPredictionGNN with the original encoder and decoder.
        """
        if isinstance(self._encoder, DistributedDataParallel):
            encoder = self._encoder.module
        else:
            encoder = self._encoder
        if isinstance(self._decoder, DistributedDataParallel):
            decoder = self._decoder.module
        else:
            decoder = self._decoder
        return LinkPredictionGNN(encoder=encoder, decoder=decoder) 
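
# A minimal usage sketch (illustrative, not from the original source): `MyEncoder` and
# `MyDecoder` are hypothetical nn.Module implementations, `batch` is a PyG Data object,
# and `query_local_ids` / `candidate_local_ids` are hypothetical index tensors.
#
#   model = LinkPredictionGNN(encoder=MyEncoder(), decoder=MyDecoder())
#   model = model.to_ddp(device=torch.device("cuda", 0))  # after torch.distributed is initialized
#   embeddings = model(batch, device=torch.device("cuda", 0))  # encoder forward runs through DDP
#   scores = model.decode(query_embeddings=embeddings[query_local_ids],
#                         candidate_embeddings=embeddings[candidate_local_ids])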
 
# TODO(swong3): Move specific models to gigl.nn.models whenever we restructure model placement.
# TODO(swong3): Abstract TorchRec functionality, and make this LightGCN specific
# TODO(swong3): Remove device context from LightGCN module (use meta, but will have to figure out how to handle buffer transfer)
class LightGCN(nn.Module):
    """
    LightGCN model with TorchRec integration for distributed ID embeddings.
    Reference: https://arxiv.org/pdf/2002.02126
    This class extends the basic LightGCN implementation to use TorchRec's
    distributed embedding tables for handling large-scale ID embeddings.
    Args:
        node_type_to_num_nodes (Union[int, Dict[NodeType, int]]): Map from node types
            to node counts. Can also pass a single int for homogeneous graphs.
        embedding_dim (int): Dimension of node embeddings D. Default: 64.
        num_layers (int): Number of LightGCN propagation layers K. Default: 2.
        device (torch.device): Device to run the computation on. Default: CPU.
        layer_weights (Optional[List[float]]): Weights for [e^(0), e^(1), ..., e^(K)].
            Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None.
    """
    def __init__(
        self,
        node_type_to_num_nodes: Union[int, dict[NodeType, int]],
        embedding_dim: int = 64,
        num_layers: int = 2,
        device: torch.device = torch.device("cpu"),
        layer_weights: Optional[list[float]] = None,
    ):
        super().__init__()
        self._node_type_to_num_nodes = to_heterogeneous_node(node_type_to_num_nodes)
        self._embedding_dim = embedding_dim
        self._num_layers = num_layers
        self._device = device
        # Construct LightGCN α weights: include e^(0) + K propagated layers ==> K+1 weights
        if layer_weights is None:
            layer_weights = [1.0 / (num_layers + 1)] * (num_layers + 1)
        else:
            if len(layer_weights) != (num_layers + 1):
                raise ValueError(
                    f"layer_weights must have length K+1={num_layers+1}, got {len(layer_weights)}"
                )
        # Register layer weights as a buffer so it moves with the model to different devices
        self.register_buffer(
            "_layer_weights",
            torch.tensor(layer_weights, dtype=torch.float32),
        )
        # Build TorchRec EBC (one table per node type)
        # feature key naming convention: f"{node_type}_id"
        self._feature_keys: list[str] = [
            f"{node_type}_id" for node_type in self._node_type_to_num_nodes.keys()
        ]
        tables: list[EmbeddingBagConfig] = []
        for node_type, num_nodes in self._node_type_to_num_nodes.items():
            tables.append(
                EmbeddingBagConfig(
                    name=f"node_embedding_{node_type}",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_nodes,
                    feature_names=[f"{node_type}_id"],
                )
            )
        self._embedding_bag_collection = EmbeddingBagCollection(
            tables=tables, device=self._device
        )
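        # Example (hypothetical sizes): node_type_to_num_nodes={"user": 1_000, "item": 5_000}
        # with embedding_dim=64 builds two tables, "node_embedding_user" and
        # "node_embedding_item", looked up via the feature keys "user_id" and "item_id".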
        # Construct LightGCN propagation layers (LGConv = Ā X)
        self._convs = nn.ModuleList(
            [LGConv() for _ in range(self._num_layers)]
        )  # K layers
    def forward(
        self,
        data: Union[Data, HeteroData],
        device: torch.device,
        output_node_types: Optional[list[NodeType]] = None,
        anchor_node_ids: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
        """
        Forward pass of the LightGCN model.
        Args:
            data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous).
            device (torch.device): Device to run the computation on.
            output_node_types (Optional[List[NodeType]]): List of node types to return
                embeddings for. Required for heterogeneous graphs. Default: None.
            anchor_node_ids (Optional[torch.Tensor]): Local node indices to return
                embeddings for. If None, returns embeddings for all nodes. Default: None.
        Returns:
            Union[torch.Tensor, Dict[NodeType, torch.Tensor]]: Node embeddings.
                For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim].
                For heterogeneous graphs, returns dict mapping node types to embeddings.
        """
        if isinstance(data, HeteroData):
            # TODO: dispatch to a heterogeneous forward pass (e.g. resolving
            # output_node_types from data.node_types) once heterogeneous support lands.
            raise NotImplementedError("HeteroData is not yet supported for LightGCN")
        else:
            return self._forward_homogeneous(data, device, anchor_node_ids) 
    def _forward_homogeneous(
        self,
        data: Data,
        device: torch.device,
        anchor_node_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for homogeneous graphs using LightGCN propagation.
        Notation follows the LightGCN paper (https://arxiv.org/pdf/2002.02126):
        - e^(0): Initial embeddings (no propagation)
        - e^(k): Embeddings after k layers of graph convolution
        - z: Final embedding = weighted sum of [e^(0), e^(1), ..., e^(K)]
        Variable naming:
        - embeddings_0: Initial embeddings e^(0) for subgraph nodes
        - embeddings_k: Current layer embeddings during propagation
        - all_layer_embeddings: List containing [e^(0), e^(1), ..., e^(K)]
        - final_embeddings: Final node embeddings (weighted sum)
        Args:
            data (Data): PyG Data object containing edge_index and node IDs.
            device (torch.device): Device to run computation on.
            anchor_node_ids (Optional[torch.Tensor]): Local node indices to return
                embeddings for. If None, returns embeddings for all nodes. Default: None.
        Returns:
            torch.Tensor: Tensor of shape [num_nodes, embedding_dim] containing
                final LightGCN embeddings.
        """
        # Check that the model is set up to be homogeneous
        assert len(self._feature_keys) == 1, (
            f"Homogeneous path expects exactly one node type; got "
            f"{len(self._feature_keys)} types: {self._feature_keys}"
        )
        key = self._feature_keys[0]
        edge_index = data.edge_index.to(
            device
        )  # shape [2, E], where E is the number of edges
        assert hasattr(
            data, "node"
        ), "Subgraph must include .node to map local→global IDs."
        global_ids = data.node.to(
            device
        ).long()  # shape [N_sub], maps local 0..N_sub-1 → global ids
        embeddings_0 = self._lookup_embeddings_for_single_node_type(
            key, global_ids
        )  # shape [N_sub, D], where N_sub is number of nodes in subgraph and D is embedding_dim
        all_layer_embeddings: list[torch.Tensor] = [embeddings_0]
        embeddings_k = embeddings_0
        for conv in self._convs:
            embeddings_k = conv(
                embeddings_k, edge_index
            )  # shape [N_sub, D], normalized neighbor averaging over *subgraph* edges
            all_layer_embeddings.append(embeddings_k)
        final_embeddings = self._weighted_layer_sum(
            all_layer_embeddings
        )  # shape [N_sub, D], weighted sum of all layer embeddings
        # If anchor node ids are provided, return the embeddings for the anchor nodes only
        if anchor_node_ids is not None:
            anchors_local = anchor_node_ids.to(device).long()  # shape [num_anchors]
            return final_embeddings[
                anchors_local
            ]  # shape [num_anchors, D], embeddings for anchor nodes only
        # Otherwise, return the embeddings for all nodes in the subgraph
        return final_embeddings  # shape [N_sub, D], embeddings for all nodes in subgraph
    def _lookup_embeddings_for_single_node_type(
        self, node_type: str, ids: torch.Tensor
    ) -> torch.Tensor:
        """
        Fetch per-ID embeddings for a single node type using EmbeddingBagCollection.
        This method constructs a KeyedJaggedTensor (KJT) that includes all EBC feature
        keys to ensure consistent forward pass behavior. For the requested node type,
        we create B bags of length 1 (one per ID). For all other node types, we create
        B bags of length 0. With SUM pooling, non-requested node types contribute zeros
        and the requested node type acts as identity lookup.
        Args:
            node_type (str): Feature key for the node type (e.g., "user_id", "item_id").
            ids (torch.Tensor): Node IDs to look up, shape [batch_size].
        Returns:
            torch.Tensor: Embeddings for the requested node type, shape [batch_size, embedding_dim].
        """
        if node_type not in self._feature_keys:
            raise KeyError(
                f"Unknown feature key '{node_type}'. Valid keys: {self._feature_keys}"
            )
        # Number of examples (one ID per "bag")
        batch_size = int(ids.numel())  # B is the number of node IDs to look up
        device = ids.device
        # Build lengths in key-major order: for each key, we give B lengths.
        # - requested key: ones (each example has 1 id)
        # - other keys: zeros (each example has 0 ids)
        lengths_per_key: list[torch.Tensor] = []
        for nt in self._feature_keys:
            if nt == node_type:
                lengths_per_key.append(
                    torch.ones(batch_size, dtype=torch.long, device=device)
                )  # shape [B], all ones for requested key
            else:
                lengths_per_key.append(
                    torch.zeros(batch_size, dtype=torch.long, device=device)
                )  # shape [B], all zeros for other keys
        lengths = torch.cat(
            lengths_per_key, dim=0
        )  # shape [batch_size * num_keys], concatenated lengths for all keys
        # Values only contain the requested key's ids (sum of other lengths is 0)
        kjt = KeyedJaggedTensor(
            keys=self._feature_keys,  # include ALL keys known by EBC
            values=ids.long(),  # shape [batch_size], only batch_size values for the requested key
            lengths=lengths,  # shape [batch_size * num_keys], batch_size lengths per key, concatenated key-major
        )
        out = self._embedding_bag_collection(
            kjt
        )  # KeyedTensor (dict-like): out[key] -> [batch_size, D]
        return out[node_type]  # shape [batch_size, D], embeddings for the requested key
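    # Illustrative KJT layout (hypothetical values) for feature keys ["user_id", "item_id"]
    # when looking up three user ids [4, 7, 9] with the helper above:
    #   keys    = ["user_id", "item_id"]
    #   values  = [4, 7, 9]            # only the requested key's ids
    #   lengths = [1, 1, 1, 0, 0, 0]   # key-major: three bags of length 1, then three of length 0
    # With SUM pooling, out["user_id"] has shape [3, D] and out["item_id"] is all zeros.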
    def _weighted_layer_sum(
        self, all_layer_embeddings: list[torch.Tensor]
    ) -> torch.Tensor:
        """
        Computes weighted sum: w_0 * e^(0) + w_1 * e^(1) + ... + w_K * e^(K).
        This implements the final aggregation step in LightGCN where embeddings from
        all layers (including the initial e^(0)) are combined using learned weights.
        Args:
            all_layer_embeddings (List[torch.Tensor]): List [e^(0), e^(1), ..., e^(K)]
                where each tensor has shape [N, D].
        Returns:
            torch.Tensor: Weighted sum of all layer embeddings, shape [N, D].
        """
        if len(all_layer_embeddings) != len(self._layer_weights):
            raise ValueError(
                f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights."
            )
        # Stack all layer embeddings and compute weighted sum
        # _layer_weights is already a tensor buffer registered in __init__
        stacked = torch.stack(all_layer_embeddings, dim=0)  # shape [K+1, N, D]
        w = self._layer_weights.to(stacked.device)  # shape [K+1], ensure on same device
        out = (stacked * w.view(-1, 1, 1)).sum(
            dim=0
        )  # shape [N, D], w_0*X_0 + w_1*X_1 + ...
        return out
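
# A minimal homogeneous usage sketch (illustrative, not from the original source). It assumes
# a tiny 4-node subgraph whose global ids are provided via `data.node`, as required by
# `_forward_homogeneous` above; node counts and dimensions are made up.
#
#   model = LightGCN(node_type_to_num_nodes=100, embedding_dim=16, num_layers=2)
#   data = Data(edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))
#   data.node = torch.tensor([10, 11, 57, 58])  # local index -> global node id
#   emb = model(data, device=torch.device("cpu"))  # shape [4, 16]
#   anchor_emb = model(data, device=torch.device("cpu"), anchor_node_ids=torch.tensor([0, 2]))  # shape [2, 16]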