Source code for gigl.experimental.knowledge_graph_embedding.lib.model.negative_sampling

from typing import Optional

import torch

from gigl.experimental.knowledge_graph_embedding.lib.model.types import (
    NegativeSamplingCorruptionType,
    SimilarityType,
)


def in_batch_relationwise_contrastive_similarity(
    src_embeddings: torch.Tensor,  # [B, D]
    dst_embeddings: torch.Tensor,  # [B, D]
    condensed_edge_types: torch.Tensor,  # [B]
    temperature: float = 1.0,
    scoring_function: SimilarityType = SimilarityType.COSINE,
    corrupt_side: NegativeSamplingCorruptionType = NegativeSamplingCorruptionType.DST,
    num_negatives: Optional[int] = None,  # Number of negatives to sample per instance
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes relation-aware in-batch contrastive similarity for knowledge graph
    embedding training.

    This function implements contrastive learning for knowledge graph embeddings,
    where the goal is to learn node representations such that related nodes
    (connected by edges) have similar embeddings while unrelated nodes have
    dissimilar embeddings.

    **Background:**
    In knowledge graph embedding, we represent entities (nodes) and relations
    (edge types) as dense vectors. For a triple (source, relation, destination),
    we want the embeddings of source and destination to be similar when they are
    connected by that relation type.

    **Contrastive Learning:**
    This function uses "in-batch" contrastive learning, meaning it creates
    negative examples by pairing each positive example with other examples in the
    same batch that have the same relation type but different source/destination
    nodes. This is computationally efficient since it reuses embeddings already
    computed for the batch.

    **Relation-aware:**
    The "relation-aware" aspect means that negative examples are only created
    within the same relation type. For example, given a positive example
    (person_A, "lives_in", city_B), negatives are only drawn from other
    "lives_in" examples, not from "works_for" examples.

    Args:
        src_embeddings (torch.Tensor): [B, D] Source node embeddings for each
            positive example. B is the batch size, D is the embedding dimension.
            Each row represents the embedding of a source entity in a knowledge
            graph triple.
        dst_embeddings (torch.Tensor): [B, D] Destination node embeddings for
            each positive example. Each row represents the embedding of a
            destination entity that should be similar to its corresponding
            source entity.
        condensed_edge_types (torch.Tensor): [B] Integer relation type IDs for
            each positive example. This identifies which relation type connects
            each source-destination pair. Examples with the same relation type
            can be used as negatives for each other.
        temperature (float, optional): Temperature parameter for scaling
            similarities before softmax. Lower values (< 1.0) sharpen the
            probability distribution, making the model more confident in its
            predictions; higher values (> 1.0) make predictions more uniform.
            Defaults to 1.0.
        scoring_function (SimilarityType, optional): Function used to compute
            similarity between embeddings. Options are:
            - COSINE: Cosine similarity (angle between vectors, normalized)
            - DOT: Dot product (unnormalized, sensitive to magnitude)
            - EUCLIDEAN: Negative squared Euclidean distance (closer = more similar)
            Defaults to SimilarityType.COSINE.
        corrupt_side (NegativeSamplingCorruptionType, optional): Which side of
            the triple to corrupt when creating negative examples:
            - DST: Replace destination nodes (e.g., (person_A, "lives_in", wrong_city))
            - SRC: Replace source nodes (e.g., (wrong_person, "lives_in", city_B))
            - BOTH: Randomly corrupt either source or destination for each example
            Defaults to NegativeSamplingCorruptionType.DST.
        num_negatives (Optional[int], optional): Number of negative examples to
            sample per positive example. If None, uses all valid negatives in the
            batch (can be computationally expensive for large batches). Setting a
            specific number (e.g., 10) makes training more efficient. Defaults to
            None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            logits (torch.Tensor): [B, 1 + K] Similarity scores for positive and
                negative pairs. The first column contains similarities for the
                true positive pairs; the remaining K columns contain similarities
                for the negative pairs. Higher values indicate higher similarity.
            labels (torch.Tensor): [B, 1 + K] Binary labels corresponding to the
                logits. The first column is all 1s (indicating positive pairs),
                and the remaining columns are all 0s (indicating negative pairs).
                Used for computing contrastive loss.

    Example:
        >>> # Batch of 3 examples with 2D embeddings
        >>> src_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
        >>> dst_emb = torch.tensor([[1.0, 0.1], [0.1, 1.0], [1.1, 1.0]])
        >>> relations = torch.tensor([0, 0, 1])  # First two share a relation type
        >>>
        >>> logits, labels = in_batch_relationwise_contrastive_similarity(
        ...     src_emb, dst_emb, relations, num_negatives=1
        ... )
        >>> # logits.shape: [3, 2] (positive + 1 negative per example)
        >>> # labels.shape: [3, 2] with first column all 1s, second column all 0s
        >>> # The third example has no in-batch negative with matching relation
        >>> # type, so its sampled negative logit is -inf and vanishes under softmax.
    """
    B, D = src_embeddings.shape

    # Compute similarity matrix between src and dst embeddings
    sim_fn = scoring_function.get_similarity_fn()
    sim_matrix = sim_fn(src_embeddings, dst_embeddings)
    sim_matrix = sim_matrix / temperature  # [B, B]

    # Create mask indicating valid relations (same relation type)
    rel_mask = condensed_edge_types[:, None] == condensed_edge_types[None, :]  # [B, B]

    # Identity matrix for diagonal masking (positive pair)
    identity = torch.diag(torch.ones_like(condensed_edge_types, dtype=torch.bool))

    # Process based on corruption side: "src", "dst", or "both"
    if corrupt_side == NegativeSamplingCorruptionType.SRC:
        # In "src" corruption, we modify the source side, keeping the destination side fixed
        sim_matrix = sim_matrix.T  # [B, B] -> Now rows are dst, columns are src
    elif corrupt_side == NegativeSamplingCorruptionType.DST:
        # In "dst" corruption, we modify the destination side, keeping the source side fixed
        # No change to sim_matrix; this is the default behavior
        pass
    elif corrupt_side == NegativeSamplingCorruptionType.BOTH:
        # In "both" corruption, randomly decide for each row whether to corrupt
        # the src or the dst (50% chance each)
        is_src_corruption = (
            torch.rand_like(condensed_edge_types, dtype=torch.float32) < 0.5
        )  # [B] Mask for src corruption

        # Corrupt the source (transpose the similarity matrix for those rows)
        sim_matrix_src = sim_matrix.T  # [B, B] -> Now rows are dst, columns are src

        # Corrupt the destination (standard sim_matrix dst corruption)
        sim_matrix_dst = sim_matrix

        # Combine the two corruptions
        sim_matrix = torch.where(
            is_src_corruption.unsqueeze(1), sim_matrix_src, sim_matrix_dst
        )
    else:
        raise ValueError(
            f"Corruption type must be in {NegativeSamplingCorruptionType.SRC, NegativeSamplingCorruptionType.DST, NegativeSamplingCorruptionType.BOTH}; got {corrupt_side}."
        )

    # Mask invalid negatives (i.e., non-matching relations and the diagonal)
    neg_mask = rel_mask & ~identity  # Mask for valid negative pairs

    # Get positive logits (diagonal of the similarity matrix)
    pos_logits = sim_matrix.diagonal().unsqueeze(1)  # [B, 1]

    # Mask the similarity matrix to only keep valid negatives
    logits_masked = sim_matrix.masked_fill(~neg_mask, float("-inf"))  # [B, B]

    if num_negatives is None:
        # No negative sampling: use all valid negatives
        logits = torch.cat([pos_logits, logits_masked], dim=1)
        labels = torch.zeros_like(logits, dtype=torch.float)
        labels = labels.scatter(
            1, torch.zeros_like(condensed_edge_types, dtype=torch.long).view(-1, 1), 1
        )  # Set positive labels to 1 (first column)
        return logits, labels

    # ---- Fully tensorized negative sampling ----
    # Generate random scores for sampling negative pairs
    rand = torch.rand_like(logits_masked)  # [B, B]
    rand.masked_fill_(
        ~neg_mask, float("inf")
    )  # Set invalid positions to +inf so they won't be selected by topk

    # Sample negatives using topk: the smallest random scores are selected
    K = num_negatives
    sampled_idx = rand.topk(K, dim=1, largest=False, sorted=False).indices  # [B, K]

    # Gather negative logits based on the sampled indices
    neg_logits = logits_masked.gather(1, sampled_idx)  # [B, K]

    # Concatenate positive logits with negative logits
    logits = torch.cat([pos_logits, neg_logits], dim=1)  # [B, 1 + K]
    labels = torch.zeros_like(logits, dtype=torch.float)
    labels = labels.scatter(
        1, torch.zeros_like(condensed_edge_types, dtype=torch.long).view(-1, 1), 1
    )  # Set positive labels to 1 (first column)
    return logits, labels
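

# Illustrative usage sketch (not part of the library API above): one way the
# logits and one-hot labels produced by in_batch_relationwise_contrastive_similarity
# might be fed into a softmax cross-entropy loss during training. The batch size,
# embedding dimension, number of relation types, temperature, and negative count
# below are arbitrary assumptions chosen only for illustration.
def _demo_in_batch_usage() -> None:
    import torch.nn.functional as F

    B, D = 8, 16  # assumed batch size and embedding dimension
    src = torch.randn(B, D)
    dst = torch.randn(B, D)
    relations = torch.randint(0, 3, (B,))  # 3 assumed relation types

    logits, labels = in_batch_relationwise_contrastive_similarity(
        src_embeddings=src,
        dst_embeddings=dst,
        condensed_edge_types=relations,
        temperature=0.1,
        num_negatives=4,
    )  # logits, labels: [B, 1 + 4]

    # The positive pair always occupies the first column, so the one-hot labels
    # reduce to class index 0 for a standard softmax cross-entropy objective.
    # Rows whose relation type has no other in-batch example get -inf negative
    # logits, which simply vanish under the softmax.
    target = labels.argmax(dim=1)  # all zeros
    loss = F.cross_entropy(logits, target)
    print(f"in-batch contrastive loss: {loss.item():.4f}")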


def against_batch_relationwise_contrastive_similarity(
    positive_src_embeddings: torch.Tensor,  # [B, D]
    positive_dst_embeddings: torch.Tensor,  # [B, D]
    positive_condensed_edge_types: torch.Tensor,  # [B]
    negative_batch_src_embeddings: torch.Tensor,  # [N, D]
    negative_batch_dst_embeddings: torch.Tensor,  # [N, D]
    batch_condensed_edge_types: torch.Tensor,  # [N]
    temperature: float = 1.0,
    scoring_function: SimilarityType = SimilarityType.COSINE,
    corrupt_side: NegativeSamplingCorruptionType = NegativeSamplingCorruptionType.DST,
    num_negatives: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes relation-aware contrastive similarity using an external batch of
    negative examples.

    This function extends contrastive learning beyond the current batch by using
    a separate, larger collection of potential negative examples. This approach
    often leads to better embedding quality since it provides more diverse and
    challenging negative examples.

    **Key Difference from In-Batch Sampling:**
    While `in_batch_relationwise_contrastive_similarity` creates negatives by
    reshuffling examples within the current batch, this function uses a
    completely separate set of pre-computed embeddings as negative candidates.
    This allows for:
    - More diverse negative examples
    - Better control over negative example quality
    - The ability to use hard negatives (examples that are difficult to distinguish from positives)
    - Larger pools of negatives without increasing batch size

    **Use Cases:**
    - When you have a large corpus of pre-computed embeddings to use as negatives
    - When implementing sophisticated negative sampling strategies (e.g., hard negatives)
    - When memory constraints limit your batch size but you want many negative examples
    - When training with cached embeddings from previous epochs

    Args:
        positive_src_embeddings (torch.Tensor): [B, D] Source node embeddings for
            the positive examples. These are the "query" embeddings that we want
            to find good matches for. B is the number of positive examples, D is
            the embedding dimension.
        positive_dst_embeddings (torch.Tensor): [B, D] Destination node
            embeddings for the positive examples. These are the true "target"
            embeddings that should be similar to their corresponding source
            embeddings. Each row pairs with the corresponding row in
            positive_src_embeddings.
        positive_condensed_edge_types (torch.Tensor): [B] Relation type IDs for
            the positive examples. Integer identifiers specifying which relation
            type connects each source-destination pair. This ensures that
            negative examples are only selected from the same relation type.
        negative_batch_src_embeddings (torch.Tensor): [N, D] Source node
            embeddings from an external batch that serve as potential negative
            examples. N is typically much larger than B, providing a rich pool
            of negative candidates. These embeddings might come from:
            - A different batch from the same dataset
            - Pre-computed embeddings from earlier training steps
            - A carefully curated set of hard negative examples
        negative_batch_dst_embeddings (torch.Tensor): [N, D] Destination node
            embeddings from the external batch. These correspond to the source
            embeddings and are used when corrupting the destination side of
            triples.
        batch_condensed_edge_types (torch.Tensor): [N] Relation type IDs for the
            external batch. Used to ensure that negative examples maintain
            relation-type consistency. Only embeddings with matching relation
            types are considered as valid negatives.
        temperature (float, optional): Temperature parameter for similarity
            scaling. Controls the "sharpness" of the resulting probability
            distribution:
            - Lower values (< 1.0): More confident, sharper distinctions
            - Higher values (> 1.0): Less confident, smoother distributions
            - 1.0: No scaling applied
            Defaults to 1.0.
        scoring_function (SimilarityType, optional): Method for computing
            embedding similarity:
            - COSINE: Normalized dot product, measures the angle between vectors (most common)
            - DOT: Raw dot product, sensitive to vector magnitudes
            - EUCLIDEAN: Negative squared L2 distance, measures geometric distance
            The choice affects how the model learns to represent relationships.
            Defaults to SimilarityType.COSINE.
        corrupt_side (NegativeSamplingCorruptionType, optional): Specifies which
            part of the knowledge graph triple to replace when creating negative
            examples:
            - DST: Replace destination nodes (e.g., (Albert_Einstein, "born_in", wrong_city))
            - SRC: Replace source nodes (e.g., (wrong_person, "born_in", Germany))
            - BOTH: Randomly replace either source or destination for each example
            Different corruption strategies can lead to different learned
            representations. Defaults to NegativeSamplingCorruptionType.DST.
        num_negatives (Optional[int], optional): Maximum number of negative
            examples to use per positive example. Controls the
            computational/memory trade-off:
            - None: Use all valid negatives from the external batch (can be expensive)
            - Small number (5-20): Fast training, fewer negatives
            - Large number (100+): Slower but potentially better quality learning
            Defaults to None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            logits (torch.Tensor): [B, 1 + K] Similarity score matrix where:
                - First column: Similarities between true positive pairs (src_i, dst_i)
                - Remaining K columns: Similarities between each positive and its K negative examples
                - Higher values indicate higher predicted similarity
                - Used as input to contrastive loss functions
            labels (torch.Tensor): [B, 1 + K] Binary label matrix corresponding to logits:
                - First column: All 1s (indicating true positive pairs)
                - Remaining K columns: All 0s (indicating negative pairs)
                - Used as targets for contrastive loss computation
                - Shape matches logits for element-wise loss calculation

    Example:
        >>> # 2 positive examples, 5 external candidates for negatives
        >>> pos_src = torch.randn(2, 128)  # 2 positive source embeddings
        >>> pos_dst = torch.randn(2, 128)  # 2 positive destination embeddings
        >>> pos_rels = torch.tensor([0, 1])  # Different relation types
        >>>
        >>> neg_src = torch.randn(5, 128)  # 5 potential negative sources
        >>> neg_dst = torch.randn(5, 128)  # 5 potential negative destinations
        >>> neg_rels = torch.tensor([0, 0, 1, 1, 2])  # Mixed relation types
        >>>
        >>> logits, labels = against_batch_relationwise_contrastive_similarity(
        ...     pos_src, pos_dst, pos_rels,
        ...     neg_src, neg_dst, neg_rels,
        ...     num_negatives=2  # Sample 2 negatives per positive
        ... )
        >>> # logits.shape: [2, 3] (1 positive + 2 negatives per example)
        >>> # labels.shape: [2, 3] (first column 1s, others 0s)
        >>> # Only relation-type-matching negatives are selected

    Note:
        This function is particularly useful in advanced training scenarios where
        you want fine-grained control over negative sampling, such as curriculum
        learning, hard negative mining, or when working with very large knowledge
        graphs where in-batch sampling provides insufficient diversity.
    """
    B, D = positive_src_embeddings.shape
    N = negative_batch_src_embeddings.shape[0]

    # Similarity function used for both positive and negative scoring
    sim_fn = scoring_function.get_similarity_fn()

    # Build mask for valid negatives per relation type
    # [B, N]: True where edge types match
    rel_mask = (
        positive_condensed_edge_types[:, None] == batch_condensed_edge_types[None, :]
    )

    # Positive similarity
    pos_logits = (
        sim_fn(positive_src_embeddings, positive_dst_embeddings)
        .diagonal()
        .unsqueeze(1)
    )  # [B, 1]

    # Negative similarity matrix (B x N), depends on corruption side
    if corrupt_side == NegativeSamplingCorruptionType.SRC:
        neg_sim_matrix = sim_fn(
            negative_batch_src_embeddings, positive_dst_embeddings
        )  # [N, B]
        neg_sim_matrix = neg_sim_matrix.T  # [B, N]
    elif corrupt_side == NegativeSamplingCorruptionType.DST:
        neg_sim_matrix = sim_fn(
            positive_src_embeddings, negative_batch_dst_embeddings
        )  # [B, N]
    elif corrupt_side == NegativeSamplingCorruptionType.BOTH:
        # Randomly decide to corrupt "src" or "dst" for each example (50% chance each)
        is_src_corruption = (
            torch.rand_like(positive_condensed_edge_types, dtype=torch.float32) < 0.5
        )  # [B] Mask for src corruption
        neg_sim_src = sim_fn(
            negative_batch_src_embeddings, positive_dst_embeddings
        ).T  # [B, N]
        neg_sim_dst = sim_fn(
            positive_src_embeddings, negative_batch_dst_embeddings
        )  # [B, N]
        neg_sim_matrix = torch.where(
            is_src_corruption.unsqueeze(1), neg_sim_src, neg_sim_dst
        )
    else:
        raise ValueError(f"Invalid corrupt_side: {corrupt_side}")

    neg_sim_matrix = neg_sim_matrix / temperature  # [B, N]

    # Mask invalid negatives (i.e., non-matching relations)
    logits_masked = neg_sim_matrix.masked_fill(~rel_mask, float("-inf"))  # [B, N]

    if num_negatives is None:
        # No negative sampling: use all valid negatives
        logits = torch.cat([pos_logits, logits_masked], dim=1)
        labels = torch.zeros_like(logits, dtype=torch.float)
        labels = labels.scatter(
            1,
            torch.zeros_like(positive_condensed_edge_types, dtype=torch.long).view(
                -1, 1
            ),
            1,
        )  # Set positive labels to 1 (first column)
        return logits, labels

    # Sample K negatives per row from matching relation types
    rand = torch.rand_like(logits_masked)  # [B, N]
    rand.masked_fill_(~rel_mask, float("inf"))  # Prevent selecting mismatched negatives
    sampled_idx = rand.topk(num_negatives, dim=1, largest=False).indices  # [B, K]

    # Gather negative similarities
    neg_logits = logits_masked.gather(1, sampled_idx)  # [B, K]

    # Concatenate positive logits with negative logits
    logits = torch.cat([pos_logits, neg_logits], dim=1)  # [B, 1 + K]
    labels = torch.zeros_like(logits, dtype=torch.float)
    labels = labels.scatter(
        1,
        torch.zeros_like(positive_condensed_edge_types, dtype=torch.long).view(-1, 1),
        1,
    )  # Set positive labels to 1 (first column)
    return logits, labels