Source code for gigl.experimental.knowledge_graph_embedding.lib.model.loss_utils

from typing import Union

import torch
import torch.nn.functional as F

# TODO(nshah-sc): Some of these functions don't require labels, and we should refactor them to not require them.
# The current implementation is parameterized as such to support some existing code.


def bpr_loss(
    scores: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    r"""
    Computes Bayesian Personalized Ranking (BPR) loss.

    For each positive score s⁺ and its K negatives s⁻₁, ..., s⁻_K, we compute:

    $$
    \mathcal{L}_\text{BPR} = - \frac{1}{BK} \sum_{i=1}^B \sum_{j=1}^K \log \sigma(s_i^+ - s_{ij}^-)
    $$

    Args:
        scores: Score tensor of shape [B, 1 + K], where B is the batch size and K is the number of
            negatives. 1st column contains positive scores, and the rest are negative scores.
        labels: Label tensor of shape [B, 1 + K], where 1 indicates positive and 0 indicates negative.
            1st column contains positive labels, and the rest are negative labels.

    Returns:
        Scalar BPR loss.
    """
    pos = scores[:, 0].unsqueeze(1)  # (B, 1)
    neg = scores[:, 1:]  # (B, K)
    diff = pos - neg  # (B, K)
    loss = -F.logsigmoid(diff).mean()  # scalar
    return loss
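
# Minimal usage sketch for bpr_loss (illustrative only; the helper name below is
# hypothetical and not part of the module API). Shapes follow the docstring
# convention of [B, 1 + K] with the positive score in column 0.
def _bpr_loss_example() -> None:
    B, K = 4, 8  # hypothetical batch size and number of negatives
    scores = torch.randn(B, 1 + K)
    labels = torch.zeros(B, 1 + K)
    labels[:, 0] = 1.0  # first column marks the positive
    loss = bpr_loss(scores, labels)  # scalar tensor; labels are accepted but unused here
    print(loss.item())
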
def infonce_loss(
    scores: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    r"""
    Computes InfoNCE contrastive loss.

    We treat each group of (1 positive + K negatives) as a (1 + K)-way classification:

    $$
    \mathcal{L}_\text{InfoNCE} = - \frac{1}{B} \sum_{i=1}^B \log \frac{\exp(s_i^+ / \tau)}{\sum_{j=0}^{K} \exp(s_{ij} / \tau)}
    $$

    Args:
        scores: Score tensor of shape [B, 1 + K], where B is the batch size and K is the number of
            negatives. 1st column contains positive scores, and the rest are negative scores.
        labels: Label tensor of shape [B, 1 + K], where 1 indicates positive and 0 indicates negative.
            1st column contains positive labels, and the rest are negative labels.
        temperature: Softmax temperature τ; scores are divided by this value before the cross-entropy.

    Returns:
        Scalar InfoNCE loss.
    """
    scores = scores / temperature  # (B, 1 + K)
    loss = F.cross_entropy(
        scores, torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
    )
    return loss
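
# Minimal usage sketch for infonce_loss (illustrative only; the helper name below is
# hypothetical). The cross-entropy target is all zeros because the positive is
# assumed to sit in column 0 of the scores.
def _infonce_loss_example() -> None:
    B, K = 4, 8  # hypothetical batch size and number of negatives
    scores = torch.randn(B, 1 + K)
    labels = torch.zeros(B, 1 + K)
    labels[:, 0] = 1.0
    loss = infonce_loss(scores, labels, temperature=0.1)  # scalar tensor
    print(loss.item())
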
def average_pos_neg_scores(
    scores: torch.Tensor, labels: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the average positive and negative scores from scores and labels.

    Args:
        scores: Score tensor.
        labels: Label tensor of corresponding shape. 1s indicate positive, 0s indicate negative.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (avg_pos_score, avg_neg_score). These are scalars, and
            the tensors are detached from the computation graph since we don't want to backprop.
    """
    one_mask = labels == 1
    avg_pos_score = scores[one_mask].mean().detach()
    avg_neg_score = scores[~one_mask].mean().detach()
    return avg_pos_score, avg_neg_score
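
# Minimal usage sketch for average_pos_neg_scores (illustrative only; the helper
# name below is hypothetical).
def _average_pos_neg_scores_example() -> None:
    scores = torch.tensor([[2.0, -1.0, 0.5], [1.5, 0.0, -0.5]])
    labels = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
    avg_pos, avg_neg = average_pos_neg_scores(scores, labels)
    print(avg_pos.item(), avg_neg.item())  # 1.75 and -0.25; both detached scalars
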
def hit_rate_at_k(
    scores: torch.Tensor,
    labels: torch.Tensor,
    ks: Union[int, list[int]],
) -> torch.Tensor:
    r"""
    Computes HitRate@K using pure tensor operations.

    HitRate@K is defined as:

    \[
    \text{HitRate@K} = \frac{1}{N} \sum_{i=1}^N \mathbb{1}\{ \text{positive in top-K} \}
    \]

    Args:
        scores: Score tensor of shape [B, 1 + K], where B is the batch size and K is the number of
            negatives. 1st column contains positive scores, and the rest are negative scores.
        labels: Label tensor of shape [B, 1 + K], where 1 indicates positive and 0 indicates negative.
            1st column contains positive labels, and the rest are negative labels.
        ks: An integer or list of integers indicating K values. Maximum K should be less than or
            equal to the number of negatives + 1.

    Returns:
        A tensor of shape (len(ks),) with HitRate@K for each requested K, ordered by ascending
        (deduplicated) K.
    """
    if isinstance(ks, int):
        ks = [ks]
    ks_tensor = torch.tensor(sorted(set(ks)), device=scores.device)

    # Get top max_k indices (shape B x max_k)
    max_k = int(ks_tensor.max().item())
    topk_indices = torch.topk(scores, k=max_k, dim=1).indices  # shape: (B, max_k)

    # Gather corresponding labels for top-k entries
    topk_labels = torch.gather(labels, dim=1, index=topk_indices)  # shape: (B, max_k)

    # For each k, compute hit (positive appeared in top-k)
    hits_at_k = (topk_labels.cumsum(dim=1) > 0).float()  # (B, max_k)
    hit_rates = hits_at_k[:, ks_tensor - 1].mean(dim=0)  # (len(ks),)

    return hit_rates
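
# Minimal usage sketch for hit_rate_at_k (illustrative only; the helper name below
# is hypothetical). In the first row the positive (column 0) ranks 1st, in the
# second row it ranks 3rd, so HitRate@1 = 0.5 and HitRate@3 = 1.0.
def _hit_rate_at_k_example() -> None:
    scores = torch.tensor([[3.0, 1.0, 2.0, 0.0], [0.5, 2.0, 1.0, -1.0]])
    labels = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])
    hit_rates = hit_rate_at_k(scores, labels, ks=[1, 3])
    print(hit_rates)  # tensor([0.5000, 1.0000])
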
def mean_reciprocal_rank(
    scores: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    r"""
    Computes Mean Reciprocal Rank (MRR) using pure tensor operations.

    MRR is defined as:

    \[
    \text{MRR} = \frac{1}{N} \sum_{i=1}^N \frac{1}{\text{rank}_i}
    \]

    where rank_i is the 1-based index of the positive in the sorted list.

    Args:
        scores: Score tensor of shape [B, 1 + K], where B is the batch size and K is the number of
            negatives. 1st column contains positive scores, and the rest are negative scores.
        labels: Label tensor of shape [B, 1 + K], where 1 indicates positive and 0 indicates negative.
            1st column contains positive labels, and the rest are negative labels.

    Returns:
        Scalar tensor with MRR.
    """
    # Sort scores descending, get sort indices
    sorted_indices = torch.argsort(scores, dim=1, descending=True)

    # Use sort indices to reorder labels
    sorted_labels = torch.gather(labels, dim=1, index=sorted_indices)

    # Find the index of the positive label (label == 1) in sorted list
    reciprocal_ranks = 1.0 / (torch.argmax(sorted_labels, dim=1).float() + 1.0)  # (B,)

    return reciprocal_ranks.mean()
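
# Minimal usage sketch for mean_reciprocal_rank (illustrative only; the helper name
# below is hypothetical). The positive ranks 1st in the first row and 3rd in the
# second, so MRR = (1/1 + 1/3) / 2 ≈ 0.6667.
def _mean_reciprocal_rank_example() -> None:
    scores = torch.tensor([[3.0, 1.0, 2.0, 0.0], [0.5, 2.0, 1.0, -1.0]])
    labels = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])
    mrr = mean_reciprocal_rank(scores, labels)
    print(mrr.item())  # ≈ 0.6667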