from __future__ import annotations
from dataclasses import dataclass
from functools import partial
import torch
import torchrec
from gigl.experimental.knowledge_graph_embedding.common.graph_dataset import (
    CONDENSED_EDGE_TYPE_FIELD,
    DST_FIELD,
    SRC_FIELD,
    HeterogeneousGraphEdgeDict,
)
from gigl.experimental.knowledge_graph_embedding.common.torchrec.batch import (
    DataclassBatch,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.dataloader import (
    DataloaderConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
    SamplingConfig,
)
from gigl.src.common.types.graph_data import (
    CondensedEdgeType,
    CondensedNodeType,
    NodeType,
)
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.training.v1.lib.data_loaders.tf_records_iterable_dataset import (
    LoopyIterableDataset,
)
@dataclass
class EdgeBatch(DataclassBatch):
    """
    A class for representing a batch of edges in a heterogeneous graph.

    This can be derived from input edge tensors, and contains logic to build
    a torchrec KeyedJaggedTensor (used for sharded embedding lookups) and
    other metadata tensors which are required to train KGE models.

    Attributes:
        src_dst_pairs: KJT keyed by the sorted condensed node types. Every key has
            ``2 * num_edges`` lengths: slot ``2 * i`` holds the source of edge ``i`` and
            slot ``2 * i + 1`` its destination. A slot has length 1 under the key matching
            that node's condensed node type and length 0 under every other key.
        condensed_edge_types: Condensed edge type of each edge, shape ``[num_edges]``.
        labels: Label of each edge, shape ``[num_edges]``.
    """

    src_dst_pairs: torchrec.KeyedJaggedTensor
    condensed_edge_types: torch.Tensor
    # ``labels`` is passed to the generated ``__init__`` by ``from_edge_tensors`` and read
    # back in ``to_edge_tensors``, so it has to be a declared dataclass field.
    labels: torch.Tensor

    @staticmethod
    def from_edge_tensors(
        edges: torch.Tensor,
        condensed_edge_types: torch.Tensor,
        edge_labels: torch.Tensor,
        condensed_node_type_to_node_type_map: dict[CondensedNodeType, NodeType],
        condensed_edge_type_to_condensed_node_type_map: dict[
            CondensedEdgeType, tuple[CondensedNodeType, CondensedNodeType]
        ],
    ) -> EdgeBatch:
        """
        Creates an EdgeBatch from edge tensors.

        Each edge contributes two KJT slots (its src followed by its dst), so every key of
        the resulting KeyedJaggedTensor has ``2 * len(edges)`` lengths.

        Args:
            edges (torch.Tensor): A tensor of edges; column 0 holds src ids, column 1 dst ids.
            condensed_edge_types (torch.Tensor): The condensed edge type of each edge.
            edge_labels (torch.Tensor): The label of each edge.
            condensed_node_type_to_node_type_map (dict[CondensedNodeType, NodeType]): A mapping from
                condensed node types to node types. Its sorted keys become the KJT keys.
            condensed_edge_type_to_condensed_node_type_map (dict[CondensedEdgeType, tuple[CondensedNodeType, CondensedNodeType]]):
                A mapping from condensed edge types to (src, dst) condensed node types.

        Returns:
            EdgeBatch: The batch holding the KJT, the edge types and the labels.

        Raises:
            ValueError: If ``edges``, ``condensed_edge_types`` and ``edge_labels`` disagree on
                the number of edges.
            KeyError: If an edge type is missing from
                ``condensed_edge_type_to_condensed_node_type_map`` or maps to a node type that
                is not a key of ``condensed_node_type_to_node_type_map``.
        """
        num_edges = edges.size(0)
        if not (num_edges == condensed_edge_types.size(0) == edge_labels.size(0)):
            # zip() below would silently truncate and produce an inconsistent KJT.
            raise ValueError(
                "edges, condensed_edge_types and edge_labels must have the same length; got "
                f"{num_edges}, {condensed_edge_types.size(0)}, {edge_labels.size(0)}"
            )
        # We canonicalize the order of keys so all KJTs are constructed the same way.
        # This ensures that when they are processed by EmbeddingBagCollections, the outputs are consistently ordered.
        cnt_keys = sorted(condensed_node_type_to_node_type_map.keys())
        lengths: dict[CondensedNodeType, list[int]] = {
            key: [0] * (2 * num_edges) for key in cnt_keys
        }
        values: dict[CondensedNodeType, list[int]] = {key: [] for key in cnt_keys}
        for i, (edge, condensed_edge_type) in enumerate(
            zip(edges, condensed_edge_types)
        ):
            src, dst = edge[0].item(), edge[1].item()
            src_cnt, dst_cnt = condensed_edge_type_to_condensed_node_type_map[
                condensed_edge_type.item()
            ]
            # Values are appended in slot order, so each key's values line up with the
            # slots whose length is 1 under that key.
            values[src_cnt].append(src)
            values[dst_cnt].append(dst)
            lengths[src_cnt][2 * i] = 1
            lengths[dst_cnt][2 * i + 1] = 1
        # KJT layout: the per-key values / lengths are concatenated in key order.
        src_dst_pairs = torchrec.KeyedJaggedTensor(
            keys=cnt_keys,
            values=torch.tensor(
                [value for key in cnt_keys for value in values[key]], dtype=torch.int32
            ),
            lengths=torch.tensor(
                [length for key in cnt_keys for length in lengths[key]],
                dtype=torch.int32,
            ),
        )
        return EdgeBatch(
            src_dst_pairs=src_dst_pairs,
            condensed_edge_types=condensed_edge_types,
            labels=edge_labels,
        )

    def to_edge_tensors(
        self,
        condensed_edge_type_to_condensed_node_type_map: dict[
            CondensedEdgeType, tuple[CondensedNodeType, CondensedNodeType]
        ],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reconstructs the edge tensors from the EdgeBatch (the inverse of ``from_edge_tensors``).

        This is used for debugging and sanity checking the EdgeBatch.

        Args:
            condensed_edge_type_to_condensed_node_type_map (dict[CondensedEdgeType, tuple[CondensedNodeType, CondensedNodeType]]):
                A mapping from condensed edge types to (src, dst) condensed node types, used to
                check that the node types stored in the KJT agree with the edge types.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(edges, condensed_edge_types, labels)``
            where ``edges`` is an int32 tensor of shape ``[num_edges, 2]``.

        Raises:
            AssertionError: If the edge counts are inconsistent or a slot's node type does not
                match its edge type.
        """
        src_dst_pairs_kjt = self.src_dst_pairs
        condensed_edge_types = self.condensed_edge_types
        edge_labels = self.labels
        node_types = list(src_dst_pairs_kjt.keys())  # the unique node types (KJT keys)
        # the num of unique node types (= num of embedding tables)
        num_node_types = len(node_types)
        # len(lengths) == 2 * num_node_types * num_edges
        num_edges = len(src_dst_pairs_kjt.lengths()) // (2 * num_node_types)
        assert (
            num_edges == len(edge_labels) == len(condensed_edge_types)
        ), f"The number of edges, edge labels and edge types should be equal.  Got {num_edges, len(edge_labels), len(condensed_edge_types)}"
        lengths_by_node_type = src_dst_pairs_kjt.lengths().view(
            num_node_types, num_edges * 2
        )
        # argmax over the key axis yields, for every src/dst slot, the *position* of the key
        # whose length is 1. Map positions back to the actual keys: condensed node types
        # are not guaranteed to be exactly 0..num_node_types-1.
        slot_node_types = [
            node_types[position]
            for position in lengths_by_node_type.argmax(dim=0).tolist()
        ]
        for edge_idx, condensed_edge_type in enumerate(condensed_edge_types.tolist()):
            (
                expected_src_cnt,
                expected_dst_cnt,
            ) = condensed_edge_type_to_condensed_node_type_map[condensed_edge_type]
            actual = (slot_node_types[2 * edge_idx], slot_node_types[2 * edge_idx + 1])
            assert actual == (
                expected_src_cnt,
                expected_dst_cnt,
            ), f"Expected condensed node types for edge type {condensed_edge_type} to be {expected_src_cnt, expected_dst_cnt}, but got {actual}"
        # Each key's values are stored in slot order, so consuming them sequentially while
        # walking the slots restores the original (src, dst) interleaving.
        values_by_node_type = {
            node_type: iter(jagged.values().tolist())
            for node_type, jagged in src_dst_pairs_kjt.to_dict().items()
        }
        reconstructed_edges = [
            next(values_by_node_type[node_type]) for node_type in slot_node_types
        ]
        reconstructed_edges_tensor = torch.tensor(
            reconstructed_edges, dtype=torch.int32
        ).view(-1, 2)
        return reconstructed_edges_tensor, condensed_edge_types, edge_labels

    @staticmethod
    def build_data_loader(
        dataset: torch.utils.data.IterableDataset,
        sampling_config: SamplingConfig,
        dataloader_config: DataloaderConfig,
        graph_metadata: GraphMetadataPbWrapper,
        condensed_node_type_to_vocab_size_map: dict[CondensedNodeType, int],
        pin_memory: bool,
        should_loop: bool = True,
    ) -> torch.utils.data.DataLoader:
        """
        Builds a DataLoader whose batches are EdgeBatch instances.

        Batching and collation are delegated to
        ``collate_edge_batch_from_heterogeneous_graph_edge_dict``, so ``dataset`` must yield
        ``HeterogeneousGraphEdgeDict`` rows.

        Args:
            dataset (torch.utils.data.IterableDataset): The source of edge rows.
            sampling_config (SamplingConfig): Supplies ``positive_edge_batch_size`` (the
                DataLoader batch size) and ``num_random_negatives_per_edge``.
            dataloader_config (DataloaderConfig): Supplies ``num_workers``.
            graph_metadata (GraphMetadataPbWrapper): Supplies the condensed edge type -> node
                types map and the condensed node type -> node type map.
            condensed_node_type_to_vocab_size_map (dict[CondensedNodeType, int]): Vocabulary size
                per condensed node type, used for random negative sampling.
            pin_memory (bool): Forwarded to ``torch.utils.data.DataLoader``.
            should_loop (bool): If True, wraps ``dataset`` in ``LoopyIterableDataset``
                (presumably so iteration restarts when the dataset is exhausted — confirm).

        Returns:
            torch.utils.data.DataLoader: The configured data loader.
        """
        # NOTE(review): with num_workers > 0 each worker iterates its own copy of an
        # IterableDataset; confirm the dataset shards by worker to avoid duplicate edges.
        dataset = (
            LoopyIterableDataset(iterable_dataset=dataset) if should_loop else dataset
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=sampling_config.positive_edge_batch_size,
            collate_fn=partial(
                collate_edge_batch_from_heterogeneous_graph_edge_dict,
                condensed_edge_type_to_condensed_node_type_map=graph_metadata.condensed_edge_type_to_condensed_node_types,
                condensed_node_type_to_vocab_size_map=condensed_node_type_to_vocab_size_map,
                condensed_node_type_to_node_type_map=graph_metadata.condensed_node_type_to_node_type_map,
                num_random_negatives_per_edge=sampling_config.num_random_negatives_per_edge,
            ),
            pin_memory=pin_memory,
            num_workers=dataloader_config.num_workers,
        )
 
def collate_edge_batch_from_heterogeneous_graph_edge_dict(
    inputs: list[HeterogeneousGraphEdgeDict],
    condensed_edge_type_to_condensed_node_type_map: dict[
        CondensedEdgeType, tuple[CondensedNodeType, CondensedNodeType]
    ],
    condensed_node_type_to_vocab_size_map: dict[CondensedNodeType, int],
    condensed_node_type_to_node_type_map: dict[CondensedNodeType, NodeType],
    num_random_negatives_per_edge: int = 0,
) -> EdgeBatch:
    """
    This is a collate function for the EdgeBatch.

    It takes a list of heterogeneous graph edge dictionaries (read from the upstream dataset),
    converts them to tensors for "positive" edges (label 1), samples random "negative" edges
    (label 0) if applicable, and constructs an EdgeBatch (containing a TorchRec
    KeyedJaggedTensor and metadata). In the resulting batch all positive edges come first,
    followed by the negative edges.

    Args:
        inputs (list[HeterogeneousGraphEdgeDict]): The input data.
        condensed_edge_type_to_condensed_node_type_map (dict[CondensedEdgeType, tuple[CondensedNodeType, CondensedNodeType]]):
            A mapping from condensed edge types to (src, dst) condensed node types.
        condensed_node_type_to_vocab_size_map (dict[CondensedNodeType, int]): A mapping from
            condensed node types to vocab sizes; used to sample negative node ids.
        condensed_node_type_to_node_type_map (dict[CondensedNodeType, NodeType]): A mapping from
            condensed node types to node types; its keys define the KJT keys.
        num_random_negatives_per_edge (int): The number of random negative edges sampled for
            *each condensed edge type* (despite the name, not for each positive edge). These
            negatives are shared by all positive edges in the batch. 0 disables sampling.

    Returns:
        EdgeBatch: The collated EdgeBatch.

    Raises:
        ValueError: If ``num_random_negatives_per_edge`` is negative.
    """
    if num_random_negatives_per_edge < 0:
        raise ValueError(
            "num_random_negatives_per_edge must be non-negative, "
            f"got {num_random_negatives_per_edge}"
        )
    # Convert the input data to tensors.
    pos_edges, pos_condensed_edge_types, pos_labels = build_tensors_from_edge_dicts(
        inputs
    )
    # Generate negative edges for the positive edges.
    if num_random_negatives_per_edge == 0:
        # If no negative samples are required, return the positive edges only.
        neg_edges = torch.empty((0, 2), dtype=torch.int32)
        neg_condensed_edge_types = torch.empty(0, dtype=torch.int32)
        neg_labels = torch.empty(0, dtype=torch.int32)
    else:
        (
            neg_edges,
            neg_condensed_edge_types,
            neg_labels,
        ) = relationwise_batch_random_negative_sampling(
            condensed_edge_type_to_condensed_node_type_map=condensed_edge_type_to_condensed_node_type_map,
            condensed_node_type_to_vocab_size_map=condensed_node_type_to_vocab_size_map,
            num_negatives_per_condensed_edge_type=num_random_negatives_per_edge,
        )
    # Construct the EdgeBatch which the model will consume (positives first, then negatives).
    return EdgeBatch.from_edge_tensors(
        edges=torch.vstack((pos_edges, neg_edges)),
        condensed_edge_types=torch.hstack(
            (pos_condensed_edge_types, neg_condensed_edge_types)
        ),
        edge_labels=torch.hstack((pos_labels, neg_labels)),
        condensed_node_type_to_node_type_map=condensed_node_type_to_node_type_map,
        condensed_edge_type_to_condensed_node_type_map=condensed_edge_type_to_condensed_node_type_map,
    )
def build_tensors_from_edge_dicts(
    inputs: list[HeterogeneousGraphEdgeDict],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Converts a list of HeterogeneousGraphEdgeDict into tensors.

    Args:
        inputs (list[HeterogeneousGraphEdgeDict]): A list of edge dictionaries.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
            - edges (torch.Tensor): An int32 tensor of shape [num_edges, 2] containing the source and destination node IDs.
            - condensed_edge_types (torch.Tensor): An int32 tensor of shape [num_edges] containing the condensed edge types.
            - labels (torch.Tensor): An int32 tensor of shape [num_edges] containing labels (all set to 1).
    """
    num_edges = len(inputs)
    # Build plain Python lists and convert each with a single torch.tensor call; writing
    # tensors one element at a time dispatches a separate torch op for every value.
    # reshape keeps the [0, 2] shape when `inputs` is empty.
    edges = torch.tensor(
        [[int(row[SRC_FIELD]), int(row[DST_FIELD])] for row in inputs],
        dtype=torch.int32,
    ).reshape(num_edges, 2)
    condensed_edge_types = torch.tensor(
        [int(row[CONDENSED_EDGE_TYPE_FIELD]) for row in inputs], dtype=torch.int32
    )
    # Every edge read from the dataset is a positive example.
    labels = torch.ones(num_edges, dtype=torch.int32)
    return edges, condensed_edge_types, labels
def relationwise_batch_random_negative_sampling(
    condensed_edge_type_to_condensed_node_type_map: dict[
        CondensedEdgeType, tuple[CondensedNodeType, CondensedNodeType]
    ],
    condensed_node_type_to_vocab_size_map: dict[CondensedNodeType, int],
    num_negatives_per_condensed_edge_type: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Performs random negative sampling for each edge type.

    For every edge type (relation) this generates `num_negatives_per_condensed_edge_type`
    edges whose src and dst are drawn uniformly at random from the vocabularies of the
    relation's source and destination node types, as given by the provided maps.
    These can be consumed in model training as negative samples which are shared across edges.

    Args:
        condensed_edge_type_to_condensed_node_type_map (dict[int, tuple[int, int]]): A mapping from each edge type
            to a tuple of (source_node_type, destination_node_type) [R].
        condensed_node_type_to_vocab_size_map (dict[int, int]): A mapping from each node type to the size of its vocabulary.
        num_negatives_per_condensed_edge_type (int): The number of negative edges to sample per edge type [K].

    Returns:
        negative_edges (Tensor): An int32 tensor of shape [R * K, 2] of (src, dst) node ids.
            Rows are grouped into contiguous blocks of K per edge type, in the iteration
            order of `condensed_edge_type_to_condensed_node_type_map`.
        negative_edge_types (Tensor): An int32 tensor of shape [R * K] containing the edge type
            for each negative edge.
        negative_labels (Tensor): An int32 tensor of zeros with shape [R * K], suitable for
            use in contrastive or classification losses.

    Raises:
        KeyError: If K > 0 and a node type used by an edge type has no vocabulary size.
        ValueError: If K > 0 and a required vocabulary size is not positive.
    """
    num_negatives = num_negatives_per_condensed_edge_type
    negative_condensed_edge_types = torch.tensor(
        list(condensed_edge_type_to_condensed_node_type_map.keys()), dtype=torch.int32
    ).repeat_interleave(num_negatives)  # [R * K]
    negative_edges = torch.zeros(
        negative_condensed_edge_types.numel(), 2, dtype=torch.int
    )  # [R * K, 2]
    # Labels are all 0 for negatives.
    negative_labels = torch.zeros_like(negative_condensed_edge_types, dtype=torch.int)
    if num_negatives:
        for relation_idx, (src_cnt, dst_cnt) in enumerate(
            condensed_edge_type_to_condensed_node_type_map.values()
        ):
            # repeat_interleave lays the relations out as contiguous blocks of K rows, in
            # dict iteration order, so each relation owns exactly one row slice.
            rows = slice(relation_idx * num_negatives, (relation_idx + 1) * num_negatives)
            src_vocab_size = condensed_node_type_to_vocab_size_map[src_cnt]
            dst_vocab_size = condensed_node_type_to_vocab_size_map[dst_cnt]
            if src_vocab_size <= 0 or dst_vocab_size <= 0:
                raise ValueError(
                    f"Vocabulary sizes must be positive; got {src_vocab_size} for node type "
                    f"{src_cnt} and {dst_vocab_size} for node type {dst_cnt}"
                )
            # torch.randint samples exactly uniformly in [0, vocab_size). Scaling float32
            # torch.rand by vocab_size loses precision for large vocabularies and can round
            # up to vocab_size, i.e. an out-of-range id.
            negative_edges[rows, 0] = torch.randint(src_vocab_size, (num_negatives,))
            negative_edges[rows, 1] = torch.randint(dst_vocab_size, (num_negatives,))
    return negative_edges, negative_condensed_edge_types, negative_labels