Source code for gigl.src.common.utils.eval_metrics

from typing import cast

import torch


def hit_rate_at_k(
    pos_scores: torch.FloatTensor,
    neg_scores: torch.FloatTensor,
    ks: torch.LongTensor,
) -> torch.FloatTensor:
    """Computes Hit Rate @ K metrics for various Ks, evaluating 1+ positives against 1+ negatives.

    Args:
        pos_scores (torch.FloatTensor): Contains 1 or more positive sample scores.
        neg_scores (torch.FloatTensor): Contains 1 or more negative sample scores.
        ks (torch.LongTensor): k-values for which to compute hits.

    Returns:
        torch.FloatTensor: Hit rates corresponding to the requested ks.
    """

    max_k_requested = int(torch.max(ks).item())
    max_viable_k = 1 + neg_scores.numel()
    min_k_requested = torch.min(ks).item()
    assert (
        min_k_requested >= 1
    ), f"ks must be greater-or-equal to 1 (got {min_k_requested})"

    pos_scores_reshaped = pos_scores.view(-1, 1)
    neg_scores_reshaped = neg_scores.view(1, -1)
    num_pos_scores = pos_scores_reshaped.shape[0]
    neg_scores_repeated = neg_scores_reshaped.repeat(num_pos_scores, 1)
    all_scores = torch.hstack((pos_scores_reshaped, neg_scores_repeated))
    all_scores_sorted = torch.argsort(all_scores, dim=1, descending=True)
    one_hot_scores = all_scores_sorted == 0
    hit_indicators = torch.cumsum(one_hot_scores, dim=1)
    hit_rates = hit_indicators.float().mean(dim=0)
    hit_rates_padded = (
        torch.cat(
            (
                hit_rates,
                torch.ones(
                    size=(max_k_requested - hit_rates.numel(),), device=hit_rates.device
                ),
            )
        )
        if max_k_requested > max_viable_k
        else hit_rates
    )
    ks_adjusted = ks - 1  # subtract 1 since indices are 0-indexed
    hits_at_ks = torch.gather(input=hit_rates_padded, dim=0, index=ks_adjusted)
    return cast(torch.FloatTensor, hits_at_ks)
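
# Illustrative usage sketch (not part of the original module); the scores below
# are made up. With two positives and four negatives, positive 0.9 ranks 1st and
# positive 0.2 ranks 4th among the 5 candidates each positive is compared against:
#
#     pos = torch.tensor([0.9, 0.2])
#     neg = torch.tensor([0.5, 0.4, 0.3, 0.1])
#     hit_rate_at_k(pos, neg, ks=torch.tensor([1, 3, 5]))
#     # tensor([0.5000, 0.5000, 1.0000])
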
def mean_reciprocal_rank(
    pos_scores: torch.FloatTensor, neg_scores: torch.FloatTensor
) -> torch.FloatTensor:
    """Computes Mean Reciprocal Rank (MRR), evaluating 1+ positives against 1+ negatives.

    Args:
        pos_scores (torch.FloatTensor): Contains 1 or more positive sample scores.
        neg_scores (torch.FloatTensor): Contains 1 or more negative sample scores.

    Returns:
        torch.FloatTensor: Computed MRR score.
    """

    pos_scores_reshaped = pos_scores.view(-1, 1)
    neg_scores_reshaped = neg_scores.view(1, -1)
    num_pos_scores = pos_scores_reshaped.shape[0]
    neg_scores_repeated = neg_scores_reshaped.repeat(num_pos_scores, 1)
    all_scores = torch.hstack((pos_scores_reshaped, neg_scores_repeated))
    all_scores_sorted = torch.argsort(all_scores, dim=1, descending=True)
    _, unadjusted_ranks = torch.where(all_scores_sorted == 0)
    adjusted_ranks = unadjusted_ranks + 1  # +1 since ranks are 0-indexed here
    reciprocal_ranks = 1.0 / adjusted_ranks  # compute reciprocal
    mrr = torch.mean(reciprocal_ranks)
    return cast(torch.FloatTensor, mrr)
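
if __name__ == "__main__":
    # Minimal runnable sketch (illustrative addition, not part of the original
    # module); the scores are made up. The positives rank 1st and 4th among their
    # 5 candidates, so MRR = (1/1 + 1/4) / 2 = 0.625.
    pos = torch.tensor([0.9, 0.2])
    neg = torch.tensor([0.5, 0.4, 0.3, 0.1])
    print(hit_rate_at_k(pos, neg, ks=torch.tensor([1, 3, 5])))  # tensor([0.5000, 0.5000, 1.0000])
    print(mean_reciprocal_rank(pos, neg))  # tensor(0.6250)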