gigl.module.loss

Classes

RetrievalLoss

A loss layer built on top of the tensorflow_recommenders implementation.

Module Contents

class gigl.module.loss.RetrievalLoss(loss=None, temperature=None, remove_accidental_hits=False)

Bases: torch.nn.Module

A loss layer built on top of the tensorflow_recommenders implementation. https://www.tensorflow.org/recommenders/api_docs/python/tfrs/tasks/Retrieval

By default, the loss is computed as `cross_entropy(torch.mm(query_embeddings, candidate_embeddings.T), positive_indices, reduction='sum')`, where candidate_embeddings is torch.cat((positive_embeddings, random_negative_embeddings)). This encourages the model to produce query embeddings whose similarity score is highest with the query's own first-hop neighbors, compared with other queries' first-hop neighbors and with random negatives. The loss can also filter out accidental hits, that is, rows in which a query would otherwise treat one of its own positives as a negative (see remove_accidental_hits).
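The default computation described above can be sketched in plain PyTorch. This is a minimal illustration, not the library's code: the tensor sizes are made up, and the targets assume that the positive for row i sits in column i of the score matrix (which follows from placing the positives first in the concatenation, but is an assumption about how positive_indices is built).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
num_positives, num_random_negatives, dim = 4, 6, 8

query_embeddings = torch.randn(num_positives, dim)
positive_embeddings = torch.randn(num_positives, dim)
random_negative_embeddings = torch.randn(num_random_negatives, dim)

# Candidates are the positives followed by the random negatives.
candidate_embeddings = torch.cat((positive_embeddings, random_negative_embeddings))

# scores[i, j] is the dot-product similarity between query i and candidate j.
scores = torch.mm(query_embeddings, candidate_embeddings.T)

# Assumption: row i's own positive is candidate i, so targets are 0..num_positives-1.
positive_indices = torch.arange(num_positives)

# Summed cross-entropy over all rows, matching reduction='sum'.
loss = F.cross_entropy(scores, positive_indices, reduction="sum")
```

Every other column in a row, including other queries' positives, acts as a negative for that row, which is why accidental hits can occur.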

Parameters:
  • loss (Optional[nn.Module]) – Custom loss function to be used. If None, the default is nn.CrossEntropyLoss(reduction="sum").

  • temperature (Optional[float]) – Temperature scaling applied to scores before computing cross-entropy loss. If not None, scores are divided by the temperature value.

  • remove_accidental_hits (bool) – Whether to remove accidental hits where the query’s positive items are also present in the negative samples.
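The effect of temperature and remove_accidental_hits can be sketched as follows. The helper names apply_temperature and mask_accidental_hits are hypothetical and not part of gigl. The masking logic is an assumption modeled on the approach used in tensorflow_recommenders (mask any non-positive column whose candidate ID equals the row's positive ID); the actual implementation may identify accidental hits differently.

```python
import torch


def apply_temperature(scores, temperature=None):
    # Scores are divided by the temperature when one is given.
    if temperature is None:
        return scores
    return scores / temperature


def mask_accidental_hits(scores, candidate_ids):
    # Assumption: positives occupy the first columns, and row i's positive is column i.
    num_rows = scores.shape[0]
    positive_ids = candidate_ids[:num_rows]

    # same_id[i, j] is True when candidate j has the same ID as row i's positive.
    same_id = positive_ids.unsqueeze(1) == candidate_ids.unsqueeze(0)

    # Mark each row's own positive column so that it is never masked.
    is_own_positive = torch.zeros_like(same_id)
    diagonal = torch.arange(num_rows)
    is_own_positive[diagonal, diagonal] = True

    # Push accidental hits to a very low score so they contribute almost nothing.
    accidental = same_id & ~is_own_positive
    return scores.masked_fill(accidental, torch.finfo(scores.dtype).min)


# Candidate 2 repeats the ID of row 0's positive (ID 10), so scores[0, 2] is masked.
scores = torch.zeros(2, 4)
candidate_ids = torch.tensor([10, 11, 10, 12])
masked = mask_accidental_hits(scores, candidate_ids)
scaled = apply_temperature(torch.ones(2, 2), temperature=0.5)
```

A temperature below 1 sharpens the softmax over candidates, while a temperature above 1 flattens it.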

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(repeated_candidate_scores, candidate_ids, repeated_query_ids, device, candidate_sampling_probability=None)
Parameters:
  • repeated_candidate_scores (torch.Tensor) – The prediction scores between each repeated query user and each candidate. Here, "repeated" means that each query user is repeated once per positive label it has. Tensor shape: [num_positives, num_positives + num_hard_negatives + num_random_negatives]

  • candidate_ids (torch.Tensor) – Concatenated IDs of the candidates. Tensor shape: [num_positives + num_hard_negatives + num_random_negatives]

  • repeated_query_ids (torch.Tensor) – Repeated query user IDs. Tensor shape: [num_positives]

  • candidate_sampling_probability (Optional[torch.Tensor]) – Optional tensor of candidate sampling probabilities. When given, it is used to correct the logits to reflect the sampling probability of negative candidates. Tensor shape: [num_positives + num_hard_negatives + num_random_negatives]

  • device (torch.device)
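To make the documented shapes concrete, the following sketch builds inputs of the right form without calling gigl. The batch sizes and IDs are invented for illustration. The correction at the end (subtracting the log of the clipped sampling probability from the logits) mirrors what tensorflow_recommenders does; that this class applies the same formula is an assumption.

```python
import torch

torch.manual_seed(0)

# Hypothetical batch: two query users with 2 and 1 positive labels respectively.
query_ids = torch.tensor([100, 200])
positives_per_query = torch.tensor([2, 1])

# Each query is repeated once per positive label it has. Shape: [num_positives]
repeated_query_ids = torch.repeat_interleave(query_ids, positives_per_query)

num_positives = int(positives_per_query.sum())
num_hard_negatives, num_random_negatives = 1, 2
num_candidates = num_positives + num_hard_negatives + num_random_negatives

# Shape: [num_positives + num_hard_negatives + num_random_negatives]
candidate_ids = torch.arange(num_candidates)

# Shape: [num_positives, num_candidates]; random values stand in for model scores.
repeated_candidate_scores = torch.randn(num_positives, num_candidates)

# One sampling probability per candidate, broadcast across rows.
candidate_sampling_probability = torch.full((num_candidates,), 0.25)

# Assumed correction: logits minus log(probability), with clipping for stability.
corrected_scores = repeated_candidate_scores - torch.log(
    candidate_sampling_probability.clamp(min=1e-6, max=1.0)
)
```

With this correction, candidates that are sampled more often have their logits reduced more, which offsets their over-representation among the negatives.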