from typing import Union
import torch
import torch.nn.functional as F
# TODO(nshah-sc): Some of these functions don't actually use labels; we should refactor them so that
# labels are not required. The current signatures keep labels for compatibility with existing code.
def bpr_loss(
scores: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Computes Bayesian Personalized Ranking (BPR) loss.
For each positive score s⁺ and its negatives s⁻₁, ..., s⁻_K, we compute:
$$
\mathcal{L}_\\text{BPR} = - \\frac{1}{BK} \sum_{i=1}^B \sum_{j=1}^K \log \sigma(s_i^+ - s_{ij}^-)
$$
Args:
scores: Score tensor of shape [B, 1 + K], where B is the batch size and K is the number of negatives.
1st column contains positive scores, and the rest are negative scores.
        labels: Label tensor of shape [B, 1 + K], where 1 indicates positive and 0 indicates negative.
            1st column contains positive labels, and the rest are negative labels. Currently unused;
            kept for signature compatibility (see module-level TODO).
Returns:
Scalar BPR loss
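    Example:
        A minimal sketch with illustrative values (the loss value is approximate):

        >>> scores = torch.tensor([[2.0, 0.5, -1.0]])  # B = 1, K = 2 negatives
        >>> labels = torch.tensor([[1.0, 0.0, 0.0]])   # positives in column 0
        >>> loss = bpr_loss(scores, labels)            # scalar tensor, ~0.125 here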
"""
pos = scores[:, 0].unsqueeze(1) # (B, 1)
neg = scores[:, 1:] # (B, K)
diff = pos - neg # (B, K)
loss = -F.logsigmoid(diff).mean() # scalar
return loss
def infonce_loss(
scores: torch.Tensor,
labels: torch.Tensor,
temperature: float = 1.0,
) -> torch.Tensor:
"""
Computes InfoNCE contrastive loss.
We treat each group of (1 positive + K negatives) as a (1 + K)-way classification:
$$
\mathcal{L}_\\text{InfoNCE} = - \\frac{1}{B} \sum_{i=1}^B \log \frac{\exp(s_i^+ / \\tau)}{\sum_{j=0}^{K} \exp(s_{ij} / \\tau)}
$$
Args:
scores: Score tensor of shape [B, 1 + K], where B is the batch size and K is the number of negatives.
1st column contains positive scores, and the rest are negative scores.
        labels: Label tensor of shape [B, 1 + K], where 1 indicates positive and 0 indicates negative.
            1st column contains positive labels, and the rest are negative labels. Currently unused;
            kept for signature compatibility (see module-level TODO).
        temperature: Temperature τ used to scale the scores before the softmax.
Returns:
Scalar InfoNCE loss
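    Example:
        A minimal sketch with illustrative values (the loss value is approximate):

        >>> scores = torch.tensor([[2.0, 0.5, -1.0]])  # B = 1, K = 2 negatives
        >>> labels = torch.tensor([[1.0, 0.0, 0.0]])   # positives in column 0
        >>> loss = infonce_loss(scores, labels, temperature=1.0)  # ~0.24 here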
"""
scores = scores / temperature # (B, 1 + K)
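    # Positives sit in column 0, so the target class index is 0 for every row.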
loss = F.cross_entropy(
scores, torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
)
return loss
def average_pos_neg_scores(
scores: torch.Tensor, labels: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Computes the average positive and negative scores from scores and labels.
Args:
scores: Score tensor.
labels: Label tensor of corresponding shape. 1s indicate positive, 0s indicate negative.
Returns:
        tuple[torch.Tensor, torch.Tensor]: (avg_pos_score, avg_neg_score). Both are scalar tensors,
            detached from the computation graph since they are meant for logging, not backprop.
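    Example:
        A minimal sketch with illustrative values:

        >>> scores = torch.tensor([[2.0, 0.5], [1.0, -1.0]])
        >>> labels = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
        >>> avg_pos, avg_neg = average_pos_neg_scores(scores, labels)  # 1.5, -0.25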
"""
one_mask = labels == 1
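    # Note: .mean() over an empty selection yields NaN (e.g. if labels contain no 1s or no 0s).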
avg_pos_score = scores[one_mask].mean().detach()
avg_neg_score = scores[~one_mask].mean().detach()
return avg_pos_score, avg_neg_score
def hit_rate_at_k(
scores: torch.Tensor,
labels: torch.Tensor,
ks: Union[int, list[int]],
) -> torch.Tensor:
"""
Computes HitRate@K using pure tensor operations.
HitRate@K is defined as:
\[
\text{HitRate@K} = \frac{1}{N} \sum_{i=1}^N \mathbb{1}\{ \text{positive in top-K} \}
\]
Args:
scores: Score tensor of shape [B, 1 + K], where B is the batch size and K is the number of negatives.
1st column contains positive scores, and the rest are negative scores.
labels: Label tensor of shape [B, 1 + K], where 1 indicates positive and 0 indicates negative.
1st column contains positive labels, and the rest are negative labels.
ks: An integer or list of integers indicating K values. Maximum K should be less than or equal
to the number of negatives + 1.
    Returns:
        A 1-D tensor with one HitRate@K value per unique K in ks, ordered by ascending K
        (duplicate K values are deduplicated).
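    Example:
        A minimal sketch with illustrative values:

        >>> scores = torch.tensor([[0.4, 0.9, 0.2], [0.8, 0.3, 0.5]])  # positives ranked 2nd and 1st
        >>> labels = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
        >>> hit_rate_at_k(scores, labels, ks=[1, 2])  # tensor([0.5, 1.0])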
"""
if isinstance(ks, int):
ks = [ks]
ks_tensor = torch.tensor(sorted(set(ks)), device=scores.device)
# Get top max_k indices (shape B x max_k)
max_k = int(ks_tensor.max().item())
topk_indices = torch.topk(scores, k=max_k, dim=1).indices # shape: (B, max_k)
# Gather corresponding labels for top-k entries
topk_labels = torch.gather(labels, dim=1, index=topk_indices) # shape: (B, max_k)
# For each k, compute hit (positive appeared in top-k)
hits_at_k = (topk_labels.cumsum(dim=1) > 0).float() # (B, max_k)
hit_rates = hits_at_k[:, ks_tensor - 1].mean(dim=0) # (len(ks),)
return hit_rates
def mean_reciprocal_rank(
scores: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Computes Mean Reciprocal Rank (MRR) using pure tensor operations.
MRR is defined as:
\[
\text{MRR} = \frac{1}{N} \sum_{i=1}^N \frac{1}{\text{rank}_i}
\]
where rank_i is the 1-based index of the positive in the sorted list.
Args:
scores: Score tensor of shape [B, 1 + K], where B is the batch size and K is the number of negatives.
1st column contains positive scores, and the rest are negative scores.
labels: Label tensor of shape [B, 1 + K], where 1 indicates positive and 0 indicates negative.
1st column contains positive labels, and the rest are negative labels.
Returns:
Scalar tensor with MRR.
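    Example:
        A minimal sketch with illustrative values:

        >>> scores = torch.tensor([[0.4, 0.9, 0.2], [0.8, 0.3, 0.5]])  # positives ranked 2nd and 1st
        >>> labels = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
        >>> mean_reciprocal_rank(scores, labels)  # (1/2 + 1/1) / 2 = tensor(0.7500)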
"""
# Sort scores descending, get sort indices
sorted_indices = torch.argsort(scores, dim=1, descending=True)
# Use sort indices to reorder labels
sorted_labels = torch.gather(labels, dim=1, index=sorted_indices)
    # Find the positive's position: argmax returns the first index of the max label,
    # which is the (single) positive's 0-based rank in the sorted list.
    reciprocal_ranks = 1.0 / (torch.argmax(sorted_labels, dim=1).float() + 1.0)  # (B,)
return reciprocal_ranks.mean()