# Source code for gigl.module.loss

from typing import Optional

import torch
import torch.nn as nn


class RetrievalLoss(nn.Module):
    """
    A loss layer built on top of the tensorflow_recommenders implementation.
    https://www.tensorflow.org/recommenders/api_docs/python/tfrs/tasks/Retrieval

    The loss function by default calculates the loss by:
    ```
    cross_entropy(torch.mm(query_embeddings, candidate_embeddings.T), positive_indices, reduction='sum'),
    ```
    where the candidate embeddings are
    `torch.cat((positive_embeddings, random_negative_embeddings))`.
    It encourages the model to generate query embeddings that yield the highest
    similarity score with their own first hop compared with others' first hops
    and random negatives. We also filter out the cases where, in some rows, the
    query could accidentally treat its own positives as negatives.

    Args:
        loss (Optional[nn.Module]): Custom loss function to be used. If `None`,
            the default is `nn.CrossEntropyLoss(reduction="sum")`. It is called
            as `loss(scores, target=labels)` where `labels` is a
            [num_queries, num_candidates] one-hot matrix.
        temperature (Optional[float]): Temperature scaling applied to scores
            before computing cross-entropy loss. If not `None`, scores are
            divided by the temperature value.
        remove_accidental_hits (bool): Whether to remove accidental hits where
            the query's positive items are also present in the negative samples.

    Raises:
        ValueError: If `temperature` is given and is smaller than 1e-12.
    """

    def __init__(
        self,
        loss: Optional[nn.Module] = None,
        temperature: Optional[float] = None,
        remove_accidental_hits: bool = False,
    ):
        super().__init__()
        self._loss = loss if loss is not None else nn.CrossEntropyLoss(reduction="sum")
        self._temperature = temperature
        # Scores are divided by the temperature, so reject (near-)zero and
        # negative values up front.
        if self._temperature is not None and self._temperature < 1e-12:
            raise ValueError(
                f"The temperature is expected to be greater than 1e-12, however you provided {self._temperature}"
            )
        self._remove_accidental_hits = remove_accidental_hits

    def _calculate_batch_retrieval_loss(
        self,
        scores: torch.Tensor,
        candidate_sampling_probability: Optional[torch.Tensor] = None,
        query_ids: Optional[torch.Tensor] = None,
        candidate_ids: Optional[torch.Tensor] = None,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """
        Compute the retrieval loss for one batch of query/candidate scores.

        The i-th candidate is treated as the positive of the i-th query, i.e.
        the label matrix is `torch.eye(num_queries, num_candidates)`.

        Args:
            scores: [num_queries, num_candidates] tensor of candidate and query
                embeddings similarity.
            candidate_sampling_probability: [num_candidates], Optional tensor of
                candidate sampling probabilities. When given will be used to
                correct the logits to reflect the sampling probability of
                negative candidates.
            query_ids: [num_queries] Optional tensor containing query ids /
                anchor node ids.
            candidate_ids: [num_candidates] Optional tensor containing
                candidate ids.
            device: Device on which the label and mask tensors are created. They
                are combined with `scores`, so this is expected to be the device
                holding `scores`.

        Returns:
            The value returned by the configured loss for the (masked) scores
            and the one-hot labels.

        Raises:
            ValueError: If accidental hit removal is enabled but
                `candidate_ids` is `None`.
        """
        num_queries: int = scores.shape[0]
        num_candidates: int = scores.shape[1]
        torch._assert(
            num_queries <= num_candidates,
            "Number of queries should be less than or equal to number of candidates in a batch",
        )

        # Allocate directly on the target device instead of building on CPU and
        # copying over.
        labels = torch.eye(
            num_queries, num_candidates, device=device
        )  # [num_queries, num_candidates]
        duplicates = torch.zeros_like(labels)  # [num_queries, num_candidates]

        if self._temperature is not None:
            scores = scores / self._temperature

        # provide the corresponding candidate sampling probability to enable sampled softmax
        if candidate_sampling_probability is not None:
            scores = scores - torch.log(
                torch.clamp(
                    candidate_sampling_probability, min=1e-10
                )  # frequency can be used so only limit its lower bound here
            ).type(scores.dtype)

        # obtain a mask that indicates true labels for each query when using multiple positives per query
        if query_ids is not None:
            duplicates = torch.maximum(
                duplicates,
                self._mask_by_query_ids(
                    query_ids, num_queries, num_candidates, labels.dtype, device
                ),
            )  # [num_queries, num_candidates]

        # obtain a mask that indicates true labels for each query when random negatives contain positives in this batch
        if self._remove_accidental_hits:
            if candidate_ids is None:
                raise ValueError(
                    "When accidental hit removal is enabled, candidate ids must be supplied."
                )
            duplicates = torch.maximum(
                duplicates,
                self._mask_by_candidate_ids(
                    candidate_ids, num_queries, labels.dtype, device
                ),
            )  # [num_queries, num_candidates]

        if query_ids is not None or self._remove_accidental_hits:
            # mask out the extra positives in each row by setting their logits to min(scores.dtype).
            # The 0/1 mask must be wide enough to hold finfo(scores.dtype).min:
            # with float64 scores a float32 mask would turn the scalar into
            # -inf and `0 * -inf` would poison every logit with NaN. Promoting
            # is a no-op whenever scores are not wider than the labels.
            penalty_dtype = torch.promote_types(scores.dtype, labels.dtype)
            extra_positives = (duplicates - labels).to(penalty_dtype)
            scores = scores + extra_positives * torch.finfo(scores.dtype).min

        return self._loss(scores, target=labels)

    def _mask_by_query_ids(
        self,
        query_ids: torch.Tensor,
        num_queries: int,
        num_candidates: int,
        dtype: torch.dtype,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """
        Build a mask whose entry (i, j) is 1 when query i and query j share an id.

        Columns beyond `num_queries` are filled with zeros.

        Args:
            query_ids: [num_queries] query ids / anchor node ids in the batch
            num_queries: number of queries / rows in the batch
            num_candidates: number of candidates / columns in the batch
            dtype: labels dtype
            device: device on which the zero padding is created

        Returns:
            [num_queries, num_candidates] mask of the given dtype.
        """
        query_ids = torch.unsqueeze(query_ids, 1)  # [num_queries, 1]
        duplicates = torch.eq(query_ids, query_ids.T).type(
            dtype
        )  # [num_queries, num_queries]
        if num_queries < num_candidates:
            padding_zeros = torch.zeros(
                (num_queries, num_candidates - num_queries),
                dtype=dtype,
                device=device,
            )
            return torch.cat(
                (duplicates, padding_zeros), dim=1
            )  # [num_queries, num_candidates]
        return duplicates

    def _mask_by_candidate_ids(
        self,
        candidate_ids: torch.Tensor,
        num_queries: int,
        dtype: torch.dtype,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """
        Build a mask whose entry (i, j) is 1 when candidate j has the same id
        as the positive candidate of query i (the i-th candidate).

        Args:
            candidate_ids: [num_candidates] candidate ids in this batch
            num_queries: number of queries / rows in the batch
            dtype: labels dtype
            device: device on which the positive indices are created

        Returns:
            [num_queries, num_candidates] mask of the given dtype.
        """
        positive_indices = torch.arange(num_queries, device=device)  # [num_queries]
        positive_candidate_ids = torch.gather(
            candidate_ids, 0, positive_indices
        ).unsqueeze(
            1
        )  # [num_queries, 1]
        all_candidate_ids = torch.unsqueeze(candidate_ids, 1)  # [num_candidates, 1]
        return torch.eq(positive_candidate_ids, all_candidate_ids.T).type(
            dtype
        )  # [num_queries, num_candidates]

    def forward(
        self,
        repeated_candidate_scores: torch.Tensor,
        candidate_ids: torch.Tensor,
        repeated_query_ids: torch.Tensor,
        device: torch.device,
        candidate_sampling_probability: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            repeated_candidate_scores (torch.Tensor): The prediction scores
                between each repeated query users and each candidates. In this
                case, `repeated` means that we repeat each query user based on
                the number of positive labels they have.
                Tensor shape: [num_positives, num_positives + num_hard_negatives + num_random_negatives]
            candidate_ids (torch.Tensor): Concatenated Ids of the candidates.
                Tensor shape: [num_positives + num_hard_negatives + num_random_negatives]
            repeated_query_ids (torch.Tensor): Repeated query user IDs.
                Tensor shape: [num_positives]
            device (torch.device): Device on which the label and mask tensors
                are created; expected to be the device holding
                `repeated_candidate_scores`.
            candidate_sampling_probability (Optional[torch.Tensor]): Optional
                tensor of candidate sampling probabilities. When given will be
                used to correct the logits to reflect the sampling probability
                of negative candidates.
                Tensor shape: [num_positives + num_hard_negatives + num_random_negatives]

        Returns:
            torch.Tensor: The batch retrieval loss produced by the configured
            loss function.
        """
        return self._calculate_batch_retrieval_loss(
            scores=repeated_candidate_scores,
            candidate_sampling_probability=candidate_sampling_probability,
            query_ids=repeated_query_ids,
            candidate_ids=candidate_ids,
            device=device,
        )