# Source code for gigl.src.common.utils.eval_metrics
from typing import cast
import torch
[docs]
def hit_rate_at_k(
    pos_scores: torch.FloatTensor, neg_scores: torch.FloatTensor, ks: torch.LongTensor
) -> torch.FloatTensor:
    """Computes Hit Rate @ K metrics for various Ks, evaluating 1+ positives against 1+ negatives.

    Each positive score is ranked (descending) against the full set of negative
    scores. The hit rate at ``k`` is the fraction of positives whose rank is
    within the top ``k``. Any ``k`` larger than ``1 + neg_scores.numel()`` (the
    worst possible rank of a positive) yields a hit rate of 1.0.

    Both score tensors are flattened, so any shape (including non-contiguous
    tensors such as transposed views) is accepted.

    Args:
        pos_scores (torch.FloatTensor): Contains 1 or more positive sample scores.
        neg_scores (torch.FloatTensor): Contains 1 or more negative sample scores.
        ks (torch.LongTensor): k-values for which to compute hits. Each must be >= 1.

    Returns:
        torch.FloatTensor: Hit rates corresponding to the requested ks.

    Raises:
        AssertionError: If any requested k is smaller than 1.
    """
    max_k_requested = int(torch.max(ks).item())
    min_k_requested = torch.min(ks).item()
    assert (
        min_k_requested >= 1
    ), f"ks must be greater-or-equal to 1 (got {min_k_requested})"

    # reshape (not view) so that non-contiguous inputs are flattened instead of
    # raising a RuntimeError.
    pos_column = pos_scores.reshape(-1, 1)
    neg_row = neg_scores.reshape(1, -1)
    num_pos_scores = pos_column.shape[0]
    max_viable_k = 1 + neg_row.shape[1]

    # One row per positive: [positive, neg_0, neg_1, ...]. expand() broadcasts the
    # negatives without materialising a copy before the concatenation.
    candidates = torch.cat((pos_column, neg_row.expand(num_pos_scores, -1)), dim=1)

    # Column 0 holds the positive, so the position of index 0 in the descending
    # ordering is the positive's 0-based rank.
    # NOTE: argsort is called without stable=True, so the relative order of a
    # positive and a negative with equal scores is not guaranteed.
    order = torch.argsort(candidates, dim=1, descending=True)
    positive_position = order == 0

    # cumsum turns the one-hot rank marker into "hit within top-(j + 1)" flags.
    hit_indicators = torch.cumsum(positive_position, dim=1)
    hit_rates = hit_indicators.float().mean(dim=0)

    if max_k_requested > max_viable_k:
        # Beyond the worst possible rank every positive is a hit.
        padding = torch.ones(
            max_k_requested - max_viable_k,
            dtype=hit_rates.dtype,
            device=hit_rates.device,
        )
        hit_rates = torch.cat((hit_rates, padding))

    # subtract 1 since indices are 0-indexed; gather requires an int64 index.
    ks_adjusted = ks.long() - 1
    hits_at_ks = torch.gather(input=hit_rates, dim=0, index=ks_adjusted)
    return cast(torch.FloatTensor, hits_at_ks)
[docs]
def mean_reciprocal_rank(
    pos_scores: torch.FloatTensor, neg_scores: torch.FloatTensor
) -> torch.FloatTensor:
    """Computes Mean Reciprocal Rank (MRR), evaluating 1+ positives against 1+ negatives.

    Each positive score is ranked (descending, 1-based) against the full set of
    negative scores; the result is the mean of ``1 / rank`` over all positives.

    Both score tensors are flattened, so any shape (including non-contiguous
    tensors such as transposed views) is accepted.

    Args:
        pos_scores (torch.FloatTensor): Contains 1 or more positive sample scores.
        neg_scores (torch.FloatTensor): Contains 1 or more negative sample scores.

    Returns:
        torch.FloatTensor: Computed MRR score (a 0-dimensional tensor).
    """
    # reshape (not view) so that non-contiguous inputs are flattened instead of
    # raising a RuntimeError.
    pos_column = pos_scores.reshape(-1, 1)
    neg_row = neg_scores.reshape(1, -1)
    num_pos_scores = pos_column.shape[0]

    # One row per positive: [positive, neg_0, neg_1, ...]. expand() broadcasts the
    # negatives without materialising a copy before the concatenation.
    candidates = torch.cat((pos_column, neg_row.expand(num_pos_scores, -1)), dim=1)

    # Column 0 holds the positive, so the column at which index 0 appears in the
    # descending ordering is that positive's 0-based rank (exactly one per row).
    # NOTE: argsort is called without stable=True, so the relative order of a
    # positive and a negative with equal scores is not guaranteed.
    order = torch.argsort(candidates, dim=1, descending=True)
    _, zero_based_ranks = torch.where(order == 0)

    ranks = zero_based_ranks + 1  # +1 since ranks are 0-indexed here
    reciprocal_ranks = 1.0 / ranks
    mrr = torch.mean(reciprocal_ranks)
    return cast(torch.FloatTensor, mrr)