import itertools
from enum import Enum
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gigl.common.logger import Logger
from gigl.src.common.types.graph_data import CondensedEdgeType
from gigl.src.common.types.task_inputs import BatchCombinedScores, BatchScores

logger = Logger()

class ModelResultType(Enum):
    batch_scores = "batch_scores"
    batch_combined_scores = "batch_combined_scores"
    batch_embeddings = "batch_embeddings"
 
class MarginLoss(nn.Module):
    """
    A loss layer built on top of the PyTorch implementation of the margin ranking loss.
    The loss pairs each positive score with every hard and random negative score and by default computes
        margin_ranking_loss(pos_scores_repeated, all_neg_scores_repeated, target=1, margin=margin, reduction='sum')
    It encourages the model to generate higher similarity scores for positive pairs than for negative pairs by at least a margin.
    See: https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html for more information.
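    Example:
        A minimal usage sketch with illustrative shapes; it assumes ``BatchScores`` accepts
        ``pos_scores``/``hard_neg_scores``/``random_neg_scores`` as keyword arguments and that
        ``CondensedEdgeType`` wraps an int.
        >>> loss_fn = MarginLoss(margin=0.5)
        >>> scores = BatchScores(
        ...     pos_scores=torch.tensor([[0.9, 0.8]]),         # shape [1, num_pos_nodes]
        ...     hard_neg_scores=torch.tensor([[0.7]]),         # shape [1, num_hard_neg_nodes]
        ...     random_neg_scores=torch.tensor([[0.1, 0.2]]),  # shape [1, num_random_neg_nodes]
        ... )
        >>> loss, sample_size = loss_fn(loss_input=[{CondensedEdgeType(0): scores}])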
    """
    def __init__(
        self,
        margin: Optional[float] = None,
    ):
        super(MarginLoss, self).__init__()
        self.margin = margin
    def _calculate_margin_loss(
        self,
        pos_scores: torch.Tensor,
        hard_neg_scores: torch.Tensor,
        random_neg_scores: torch.Tensor,
        device: torch.device = torch.device("cpu"),
    ) -> Tuple[torch.Tensor, int]:
        all_neg_scores = torch.cat(
            (hard_neg_scores, random_neg_scores),
            dim=1,
        )  # shape=[1, num_hard_neg_nodes + num_random_neg_nodes]
        all_neg_scores_repeated = all_neg_scores.repeat(
            1, pos_scores.shape[1]
        )  # shape=[1, (num_hard_neg_nodes + num_random_neg_nodes) * num_pos_nodes]
        pos_scores_repeated = pos_scores.repeat_interleave(
            all_neg_scores.shape[1], dim=1
        )  # shape=[1, num_pos_nodes * (num_hard_neg_nodes + num_random_neg_nodes)]
        ys = torch.ones_like(pos_scores_repeated).to(
            device=device
        )  # shape=[1, num_pos_nodes * (num_hard_neg_nodes + num_random_neg_nodes)]
        loss = F.margin_ranking_loss(
            input1=pos_scores_repeated,
            input2=all_neg_scores_repeated,
            target=ys,
            margin=self.margin,  # type: ignore
            reduction="sum",
        )
        sample_size = pos_scores_repeated.numel()
        return loss, sample_size
    def forward(
        self,
        loss_input: list[dict[CondensedEdgeType, BatchScores]],
        device: torch.device = torch.device("cpu"),
    ) -> Tuple[torch.Tensor, int]:
        batch_loss = torch.tensor(0.0).to(device=device)
        batch_size = 0
        # In case we have an empty list as input, avoids division by zero error
        if not len(loss_input):
            batch_size = 1
        for result_sample in loss_input:
            for condensed_edge_type in result_sample:
                if result_sample[condensed_edge_type].pos_scores.numel():
                    sample_loss, sample_size = self._calculate_margin_loss(
                        pos_scores=result_sample[condensed_edge_type].pos_scores,
                        hard_neg_scores=result_sample[
                            condensed_edge_type
                        ].hard_neg_scores,
                        random_neg_scores=result_sample[
                            condensed_edge_type
                        ].random_neg_scores,
                        device=device,
                    )
                    batch_loss += sample_loss
                    batch_size += sample_size
        return batch_loss, batch_size 
 
class SoftmaxLoss(nn.Module):
    """
    A loss layer built on top of the PyTorch implementation of the softmax cross entropy loss.
    The loss concatenates each positive score with all negative scores and by default computes
        cross_entropy(all_scores / softmax_temperature, ys, reduction='sum')
    where ys indexes the positive column (index 0) in each row.
    See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for more information.
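    Example:
        A minimal usage sketch mirroring the MarginLoss example; constructing ``BatchScores`` via
        keyword arguments is assumed.
        >>> loss_fn = SoftmaxLoss(softmax_temperature=0.07)
        >>> scores = BatchScores(
        ...     pos_scores=torch.tensor([[0.9, 0.8]]),
        ...     hard_neg_scores=torch.tensor([[0.7]]),
        ...     random_neg_scores=torch.tensor([[0.1, 0.2]]),
        ... )
        >>> loss, sample_size = loss_fn(loss_input=[{CondensedEdgeType(0): scores}])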
    """
    def __init__(
        self,
        softmax_temperature: Optional[float] = None,
    ):
        super(SoftmaxLoss, self).__init__()
        self.softmax_temperature = softmax_temperature 
    def _calculate_softmax_loss(
        self,
        pos_scores: torch.Tensor,
        hard_neg_scores: torch.Tensor,
        random_neg_scores: torch.Tensor,
        device: torch.device,
    ) -> Tuple[torch.Tensor, int]:
        all_neg_scores = torch.cat(
            (hard_neg_scores, random_neg_scores),
            dim=1,
        ).squeeze()  # shape=[num_hard_neg_nodes + num_random_neg_nodes]
        all_neg_scores_repeated = all_neg_scores.repeat(
            pos_scores.shape[1], 1
        )  # shape=[num_pos_nodes, num_hard_neg_nodes + num_random_neg_nodes]
        all_scores = torch.cat(
            (
                pos_scores.reshape(-1, 1),
                all_neg_scores_repeated,
            ),
            dim=1,
        )  # shape=[num_pos_nodes, 1 + num_hard_neg_nodes + num_random_neg_nodes]
        ys = (
            torch.zeros(pos_scores.shape[1]).long().to(device=device)
        )  # shape=[num_pos_nodes]
        loss = F.cross_entropy(
            input=all_scores / self.softmax_temperature,
            target=ys,
            reduction="sum",
        )
        sample_size = pos_scores.shape[1]
        return loss, sample_size
    def forward(
        self,
        loss_input: list[dict[CondensedEdgeType, BatchScores]],
        device: torch.device = torch.device("cpu"),
    ) -> Tuple[torch.Tensor, int]:
        batch_loss = torch.tensor(0.0).to(device=device)
        batch_size = 0
        # In case we have an empty list as input, avoids division by zero error
        if not len(loss_input):
            batch_size = 1
        for result_sample in loss_input:
            for condensed_edge_type in result_sample:
                if result_sample[condensed_edge_type].pos_scores.numel():
                    sample_loss, sample_size = self._calculate_softmax_loss(
                        pos_scores=result_sample[condensed_edge_type].pos_scores,
                        hard_neg_scores=result_sample[
                            condensed_edge_type
                        ].hard_neg_scores,
                        random_neg_scores=result_sample[
                            condensed_edge_type
                        ].random_neg_scores,
                        device=device,
                    )
                    batch_loss += sample_loss
                    batch_size += sample_size
        return batch_loss, batch_size 
 
class RetrievalLoss(nn.Module):
    """
    A loss layer built on top of the tensorflow_recommenders implementation.
    https://www.tensorflow.org/recommenders/api_docs/python/tfrs/tasks/Retrieval
    The loss function by default calculates the loss by:
    ```
    cross_entropy(torch.mm(query_embeddings, candidate_embeddings.T), positive_indices, reduction='sum'),
    ```
    where the candidate embeddings are `torch.cat((positive_embeddings, random_negative_embeddings))`.
    It encourages the model to generate query embeddings whose similarity with their own first hops is higher
    than with other queries' first hops and with random negatives. We also filter out accidental hits, i.e.
    rows where a query's own positives would otherwise be treated as negatives.
    Args:
        loss (Optional[nn.Module]): Custom loss function to be used. If `None`, the default is `nn.CrossEntropyLoss(reduction="sum")`.
        temperature (Optional[float]): Temperature scaling applied to scores before computing cross-entropy loss. If not `None`, scores are divided by the temperature value.
        remove_accidental_hits (bool): Whether to remove accidental hits where the query's positive items are also present in the negative samples.
    """
    def __init__(
        self,
        loss: Optional[nn.Module] = None,
        temperature: Optional[float] = None,
        remove_accidental_hits: bool = False,
    ):
        super(RetrievalLoss, self).__init__()
        self._loss = loss if loss is not None else nn.CrossEntropyLoss(reduction="sum")
        self._temperature = temperature
        if self._temperature is not None and self._temperature < 1e-12:
            raise ValueError(
                f"The temperature is expected to be greater than 1e-12, however you provided {self._temperature}"
            )
        self._remove_accidental_hits = remove_accidental_hits
        logger.warning(
            "Calculating retrieval loss with the class at gigl.src.common.models.layers.loss.RetrievalLoss is deprecated and will be removed in a future release. "
            "Please use the `gigl.module.loss.RetrievalLoss` class instead."
        )
    def calculate_batch_retrieval_loss(
        self,
        scores: torch.Tensor,
        candidate_sampling_probability: Optional[torch.Tensor] = None,
        query_ids: Optional[torch.Tensor] = None,
        candidate_ids: Optional[torch.Tensor] = None,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """
        Args:
          scores: [num_queries, num_candidates] tensor of candidate and query embeddings similarity
          candidate_sampling_probability: [num_candidates] Optional tensor of candidate sampling probabilities.
            When given, it is used to correct the logits to reflect the sampling probability of negative candidates.
          query_ids: [num_queries] Optional tensor containing query ids / anchor node ids.
          candidate_ids: [num_candidates] Optional tensor containing candidate ids.
          device: the device to set as default
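        Example:
            A minimal sketch with a random score matrix for 4 queries and 10 candidates; no
            sampling-probability correction or accidental-hit removal is applied.
            >>> loss_fn = RetrievalLoss(temperature=0.1)
            >>> scores = torch.randn(4, 10)  # [num_queries, num_candidates]
            >>> loss = loss_fn.calculate_batch_retrieval_loss(scores=scores)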
        """
        num_queries: int = scores.shape[0]
        num_candidates: int = scores.shape[1]
        torch._assert(
            num_queries <= num_candidates,
            "Number of queries should be less than or equal to number of candidates in a batch",
        )
        labels = torch.eye(num_queries, num_candidates).to(
            device=device
        )  # [num_queries, num_candidates]
        duplicates = torch.zeros_like(labels).to(
            device=device
        )  # [num_queries, num_candidates]
        if self._temperature is not None:
            scores = scores / self._temperature
        # provide the corresponding candidate sampling probability to enable sampled softmax
        if candidate_sampling_probability is not None:
            scores = scores - torch.log(
                torch.clamp(
                    candidate_sampling_probability, min=1e-10
                )  # raw frequencies may be passed, so only the lower bound is clamped
            ).type(scores.dtype)
        # obtain a mask that indicates true labels for each query when using multiple positives per query
        if query_ids is not None:
            duplicates = torch.maximum(
                duplicates,
                self._mask_by_query_ids(
                    query_ids, num_queries, num_candidates, labels.dtype, device
                ),
            )  # [num_queries, num_candidates]
        # obtain a mask that indicates true labels for each query when random negatives contain positives in this batch
        if self._remove_accidental_hits:
            if candidate_ids is None:
                raise ValueError(
                    "When accidental hit removal is enabled, candidate ids must be supplied."
                )
            duplicates = torch.maximum(
                duplicates,
                self._mask_by_candidate_ids(
                    candidate_ids, num_queries, labels.dtype, device
                ),
            )  # [num_queries, num_candidates]
        if query_ids is not None or self._remove_accidental_hits:
            # mask out the extra positives in each row by setting their logits to min(scores.dtype)
            scores = scores + (duplicates - labels) * torch.finfo(scores.dtype).min
        return self._loss(scores, target=labels) 
    def _mask_by_query_ids(
        self,
        query_ids: torch.Tensor,
        num_queries: int,
        num_candidates: int,
        dtype: torch.dtype,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """
        Args:
            query_ids: [num_queries] query ids / anchor node ids in the batch
            num_queries: number of queries / rows in the batch
            num_candidates: number of candidates / columns in the batch
            dtype: labels dtype
            device: the device to set as default
        """
        query_ids = torch.unsqueeze(query_ids, 1)  # [num_queries, 1]
        duplicates = torch.eq(query_ids, query_ids.T).type(
            dtype
        )  # [num_queries, num_queries]
        if num_queries < num_candidates:
            padding_zeros = torch.zeros(
                (num_queries, num_candidates - num_queries), dtype=dtype
            ).to(device=device)
            return torch.cat(
                (duplicates, padding_zeros), dim=1
            )  # [num_queries, num_candidates]
        return duplicates
    def _mask_by_candidate_ids(
        self,
        candidate_ids: torch.Tensor,
        num_queries: int,
        dtype: torch.dtype,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """
        Args:
            candidate_ids: [num_candidates] candidate ids in this batch
            num_queries: number of queries / rows in the batch
            dtype: labels dtype
            device: the device to set as default
        """
        positive_indices = torch.arange(num_queries).to(device=device)  # [num_queries]
        positive_candidate_ids = torch.gather(
            candidate_ids, 0, positive_indices
        ).unsqueeze(
            1
        )  # [num_queries, 1]
        all_candidate_ids = torch.unsqueeze(candidate_ids, 1)  # [num_candidates, 1]
        return torch.eq(positive_candidate_ids, all_candidate_ids.T).type(
            dtype
        )  # [num_queries, num_candidates]
    def forward(
        self,
        batch_combined_scores: BatchCombinedScores,
        repeated_query_embeddings: torch.FloatTensor,
        candidate_sampling_probability: Optional[torch.FloatTensor] = None,
        device: torch.device = torch.device("cpu"),
    ) -> Tuple[torch.Tensor, int]:
        candidate_ids = torch.cat(
            (
                batch_combined_scores.positive_ids.to(device=device),
                batch_combined_scores.hard_neg_ids.to(device=device),
                batch_combined_scores.random_neg_ids.to(device=device),
            )
        )
        if repeated_query_embeddings.numel():  # type: ignore
            loss = self.calculate_batch_retrieval_loss(
                scores=batch_combined_scores.repeated_candidate_scores,
                candidate_sampling_probability=candidate_sampling_probability,
                query_ids=batch_combined_scores.repeated_query_ids,
                candidate_ids=candidate_ids,
                device=device,
            )
            batch_size = repeated_query_embeddings.shape[0]  # type: ignore
        else:
            loss = torch.tensor(0.0).to(device=device)
            batch_size = 1
        return loss, batch_size 
 
class GRACELoss(nn.Module):
    """
    A loss class that implements the GRACE (https://arxiv.org/pdf/2006.04131.pdf) contrastive loss approach. We generate two graph views by
    corruption and learn node representations by maximizing the agreement of node representations in these two views. We introduce this to add an
    additional contrastive loss function for multi-task learning.
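    Example:
        A minimal usage sketch with random node embeddings from two augmented views.
        >>> loss_fn = GRACELoss(temperature=0.5)
        >>> h1, h2 = torch.randn(8, 16), torch.randn(8, 16)
        >>> loss, sample_size = loss_fn(h1, h2)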
    """
    def __init__(
        self,
        temperature: Optional[float] = None,
    ):
        super(GRACELoss, self).__init__()
        self.temperature = temperature 
    def forward(
        self,
        h1: torch.Tensor,
        h2: torch.Tensor,
        device: torch.device = torch.device("cpu"),
    ) -> Tuple[torch.Tensor, int]:
        """
        Args:
            h1 (torch.Tensor): First input tensor
            h2 (torch.Tensor): Second input tensor
            device (torch.device): the device to set as default
        Returns:
            Tuple[torch.Tensor, int]: The loss and the sample size
        """
        def sim_matrix(a: torch.Tensor, b: torch.Tensor, eps=1e-8) -> torch.Tensor:
            """
            Computes the cosine similarity matrix between the rows of 'a' and 'b' by L2-normalizing each row first.
            """
            a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
            a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
            b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
            sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)).to(device=device)
            return sim_mt
        def get_loss(h1: torch.Tensor, h2: torch.Tensor) -> torch.Tensor:
            """
            Uses cosine similarity matrices between intra-view pairs and inter-view pairs to compute the loss.
            """
            f = lambda x: torch.exp(x / self.temperature)
            refl_sim = f(sim_matrix(h1, h1))  # intra-view pairs
            between_sim = f(sim_matrix(h1, h2))  # inter-view pairs
            x1 = refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()
            loss = -torch.log(between_sim.diag() / x1)
            return loss
        l1 = get_loss(h1, h2)
        l2 = get_loss(h2, h1)
        ret = (l1 + l2) * 0.5
        return ret.mean(), 1 
 
class FeatureReconstructionLoss(nn.Module):
    """
    Computes the scaled cosine error (SCE) between the original feature and the reconstructed feature. See https://arxiv.org/pdf/2205.10803.pdf for more information about
    feature reconstruction. We use this as an auxiliary loss for training and improved generalization.
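    Example:
        A minimal usage sketch comparing random original and reconstructed feature matrices.
        >>> loss_fn = FeatureReconstructionLoss(alpha=3.0)
        >>> loss, sample_size = loss_fn(x_target=torch.randn(8, 16), x_pred=torch.randn(8, 16))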
    """
    def __init__(
        self,
        alpha: float = 3.0,
    ):
        super(FeatureReconstructionLoss, self).__init__()
        self.alpha = alpha
    def forward(
        self,
        x_target: torch.Tensor,
        x_pred: torch.Tensor,
    ) -> Tuple[torch.Tensor, int]:
        x = F.normalize(x_target, p=2, dim=-1)  # SCE Loss Computation
        y = F.normalize(x_pred, p=2, dim=-1)
        loss = (1 - (x * y).sum(dim=-1)).pow_(self.alpha)
        loss = loss.mean()
        return loss, 1 
 
class WhiteningDecorrelationLoss(nn.Module):
    """
    Utilizes canonical correlation analysis to compute similarity between augmented graphs as an auxiliary loss. See https://arxiv.org/pdf/2106.12484.pdf
    for more information.
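    Example:
        A minimal usage sketch with two random views of N = 32 samples.
        >>> loss_fn = WhiteningDecorrelationLoss(lambd=1e-3)
        >>> h1, h2 = torch.randn(32, 16), torch.randn(32, 16)
        >>> loss, sample_size = loss_fn(h1, h2, N=32)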
    """
    def __init__(
        self,
        lambd: float = 1e-3,
    ):
        super(WhiteningDecorrelationLoss, self).__init__()
        self.lambd = lambd
    def forward(
        self,
        h1: torch.Tensor,
        h2: torch.Tensor,
        N: int,
        device: torch.device = torch.device("cpu"),
    ) -> Tuple[torch.Tensor, int]:
        """
        Args:
            h1 (torch.Tensor): First input tensor
            h2 (torch.Tensor): Second input tensor
            N (int): The number of samples
            device (torch.device): the device to set as default
        Returns:
            Tuple[torch.Tensor, int]: The loss and the sample size
        """
        z1 = (h1 - h1.mean(0)) / h1.std(0)
        z2 = (h2 - h2.mean(0)) / h2.std(0)
        c1 = torch.mm(z1.T, z1)
        c2 = torch.mm(z2.T, z2)
        c = (z1 - z2) / N
        c1 = c1 / N
        c2 = c2 / N
        loss_inv = torch.linalg.matrix_norm(c)
        iden = torch.tensor(np.eye(c1.shape[0])).to(device=device)
        loss_dec1 = torch.linalg.matrix_norm(iden - c1)
        loss_dec2 = torch.linalg.matrix_norm(iden - c2)
        return loss_inv + self.lambd * (loss_dec1 + loss_dec2), 1 
 
class GBTLoss(nn.Module):
    """
    Computes the Barlow Twins loss on the two input matrices as an auxiliary loss.
    From the official GBT implementation at:
    https://github.com/pbielak/graph-barlow-twins/blob/ec62580aa89bf3f0d20c92e7549031deedc105ab/gssl/loss.py
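    Example:
        A minimal usage sketch with two random embedding matrices.
        >>> loss_fn = GBTLoss()
        >>> loss, sample_size = loss_fn(torch.randn(32, 16), torch.randn(32, 16), device=torch.device("cpu"))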
    """
    def __init__(
        self,
    ):
        super(GBTLoss, self).__init__()
        # Small constant for numerical stability in batch normalization; matches the EPS used in the reference GBT implementation.
        self.eps = 1e-15
    def forward(
        self,
        z_a: torch.Tensor,
        z_b: torch.Tensor,
        device: torch.device,
    ) -> Tuple[torch.Tensor, int]:
        """
        Args:
            z_a (torch.Tensor): First input matrix
            z_b (torch.Tensor): Second input matrix
            device (torch.device): the device to set as default
        Returns:
            Tuple[torch.Tensor, int]: The Barlow Twins loss and the sample size
        """
        batch_size = z_a.size(0)
        feature_dim = z_a.size(1)
        _lambda = 1 / feature_dim
        # Apply batch normalization
        z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + self.eps)
        z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + self.eps)
        # Cross-correlation matrix
        c = (z_a_norm.T @ z_b_norm) / batch_size
        # Loss function
        off_diagonal_mask = ~torch.eye(feature_dim).bool().to(device=device)
        loss = (1 - c.diagonal()).pow(2).sum() + _lambda * c[off_diagonal_mask].pow(
            2
        ).sum()
        return loss, 1 
 
class BGRLLoss(nn.Module):
    """
    Leverages BGRL loss from https://arxiv.org/pdf/2102.06514.pdf, using an offline and online encoder to predict alternative augmentations of
    the input. The offline encoder is updated by an exponential moving average rather than traditional backpropagation. We use BGRL as an
    auxiliary loss for improved generalization.
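    Example:
        A minimal usage sketch with random online predictions (q1, q2) and offline targets (y1, y2).
        >>> loss_fn = BGRLLoss()
        >>> q1, q2, y1, y2 = (torch.randn(8, 16) for _ in range(4))
        >>> loss, sample_size = loss_fn(q1, q2, y1, y2)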
    """
    def forward(
        self,
        q1: torch.Tensor,
        q2: torch.Tensor,
        y1: torch.Tensor,
        y2: torch.Tensor,
    ) -> Tuple[torch.Tensor, int]:
        loss = (
            2
            - F.cosine_similarity(q1, y2.detach(), dim=-1).mean()
            - F.cosine_similarity(q2, y1.detach(), dim=-1).mean()
        )
        return loss, 1 
 
class TBGRLLoss(nn.Module):
    """
    TBGRL (https://arxiv.org/pdf/2211.14394.pdf) improves over BGRL by generating a third augmented graph as a negative sample,
    providing a cheap corruption that improves generalizability of the model in inductive settings. We use TBGRL as an auxiliary loss
    for improved generalization.
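    Example:
        A minimal usage sketch; ``neg_y`` plays the role of the corrupted (negative) view's target embeddings.
        >>> loss_fn = TBGRLLoss(neg_lambda=0.12)
        >>> q1, q2, y1, y2, neg_y = (torch.randn(8, 16) for _ in range(5))
        >>> loss, sample_size = loss_fn(q1, q2, y1, y2, neg_y)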
    """
    def __init__(
        self,
        neg_lambda: float = 0.12,
    ):
        super(TBGRLLoss, self).__init__()
        self.neg_lambda = neg_lambda 
    def forward(
        self,
        q1: torch.Tensor,
        q2: torch.Tensor,
        y1: torch.Tensor,
        y2: torch.Tensor,
        neg_y: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, int]:
        sim1 = F.cosine_similarity(q1, y2.detach()).mean()
        sim2 = F.cosine_similarity(q2, y1.detach()).mean()
        neg_sim1 = F.cosine_similarity(q1, neg_y.detach()).mean()  # type: ignore
        neg_sim2 = F.cosine_similarity(q2, neg_y.detach()).mean()  # type: ignore
        loss = self.neg_lambda * (neg_sim1 + neg_sim2) - (1 - self.neg_lambda) * (  # type: ignore
            sim1 + sim2
        )
        return loss, 1 
 
class AligmentLoss(nn.Module):
    """
    Taken from https://github.com/THUwangcy/DirectAU, AlignmentLoss increases the similarity of representations between positive user-item pairs.
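    Example:
        A minimal usage sketch with random user and item embeddings for positive pairs.
        >>> loss_fn = AligmentLoss()
        >>> loss = loss_fn(user_embeddings=torch.randn(8, 16), item_embeddings=torch.randn(8, 16))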
    """
    def __init__(
        self,
        alpha: Optional[float] = 2.0,  # Should not tune this parameter
    ):
        super(AligmentLoss, self).__init__()
        self.alpha = alpha
    def forward(
        self, user_embeddings: torch.Tensor, item_embeddings: torch.Tensor
    ) -> torch.Tensor:
        return (user_embeddings - item_embeddings).norm(p=2, dim=1).pow(self.alpha).mean()  # type: ignore 
 
# TODO Add Unit test for this loss
class KLLoss(nn.Module):
    """
    Calculates the KL divergence between two sets of scores for the distribution loss.
    Taken from: https://github.com/snap-research/linkless-link-prediction/blob/main/src/main.py
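    Example:
        A minimal usage sketch with random student and teacher score matrices of shape [batch_size, num_candidates].
        >>> loss_fn = KLLoss(kl_temperature=1.0)
        >>> loss = loss_fn(student_scores=torch.randn(8, 5), teacher_scores=torch.randn(8, 5))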
    """
    def __init__(
        self,
        kl_temperature: float,
    ):
        super(KLLoss, self).__init__()
        self.kl_temperature = kl_temperature 
    def forward(
        self,
        student_scores: torch.Tensor,
        teacher_scores: torch.Tensor,
    ) -> torch.Tensor:
        y_s = F.log_softmax(student_scores / self.kl_temperature, dim=-1)
        y_t = F.softmax(teacher_scores / self.kl_temperature, dim=-1)
        loss = (
            F.kl_div(y_s, y_t, reduction="sum")  # equivalent to the deprecated size_average=False
            * (self.kl_temperature**2)
            / y_s.size()[0]
        )
        return loss 
 
# TODO Add Unit test for this loss
class LLPRankingLoss(nn.Module):
    """
    Calculates a margin-based ranking loss between two sets of scores for the ranking loss in LLP.
    This differs from a normal margin loss in that it prevents the student model from trying to
    differentiate minuscule differences in probabilities that the teacher may produce due to noise.
    Taken from: https://github.com/snap-research/linkless-link-prediction/blob/main/src/main.py
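    Example:
        A minimal usage sketch; each row holds scores for several candidate links so that column pairs can be ranked.
        >>> loss_fn = LLPRankingLoss(margin=0.05)
        >>> loss = loss_fn(
        ...     student_scores=torch.randn(8, 4),
        ...     teacher_scores=torch.randn(8, 4),
        ...     device=torch.device("cpu"),
        ... )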
    """
    def __init__(
        self,
        margin: float,
    ):
        super(LLPRankingLoss, self).__init__()
        self.margin = margin
        self.margin_loss = nn.MarginRankingLoss(margin=margin)
    def forward(
        self,
        student_scores: torch.Tensor,
        teacher_scores: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        dim_pairs = [
            x for x in itertools.combinations(range(student_scores.shape[1]), r=2)
        ]
        pair_array = np.array(dim_pairs).T
        teacher_rank_list = torch.zeros((len(teacher_scores), pair_array.shape[1])).to(
            device
        )
        mask = teacher_scores[:, pair_array[0]] > (
            teacher_scores[:, pair_array[1]] + self.margin
        )
        teacher_rank_list[mask] = 1
        mask2 = teacher_scores[:, pair_array[0]] < (
            teacher_scores[:, pair_array[1]] - self.margin
        )
        teacher_rank_list[mask2] = -1
        first_rank_list = student_scores[:, pair_array[0]].squeeze()
        second_rank_list = student_scores[:, pair_array[1]].squeeze()
        return self.margin_loss(first_rank_list, second_rank_list, teacher_rank_list)