from typing import Optional
import torch
from gigl.experimental.knowledge_graph_embedding.lib.model.types import (
NegativeSamplingCorruptionType,
SimilarityType,
)
def in_batch_relationwise_contrastive_similarity(
src_embeddings: torch.Tensor, # [B, D]
dst_embeddings: torch.Tensor, # [B, D]
condensed_edge_types: torch.Tensor, # [B]
temperature: float = 1.0,
scoring_function: SimilarityType = SimilarityType.COSINE,
corrupt_side: NegativeSamplingCorruptionType = NegativeSamplingCorruptionType.DST,
num_negatives: Optional[int] = None, # Number of negatives to sample per instance
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Computes relation-aware in-batch contrastive similarity for knowledge graph embedding training.
This function implements contrastive learning for knowledge graph embeddings, where the goal is
to learn node representations such that related nodes (connected by edges) have similar embeddings
while unrelated nodes have dissimilar embeddings.
**Background:**
In knowledge graph embedding, we represent entities (nodes) and relations (edge types) as dense
vectors. For a triple (source, relation, destination), we want the embeddings of source and
destination to be similar when they're connected by that relation type.
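For instance, under cosine scoring (one of the options described below; this tiny snippet is
purely illustrative and does not call this function), aligned embeddings score near 1 and
orthogonal ones score 0:
>>> import torch.nn.functional as F
>>> a = F.normalize(torch.tensor([[1.0, 0.0]]), dim=-1)
>>> b = F.normalize(torch.tensor([[0.6, 0.8]]), dim=-1)
>>> (a * b).sum(dim=-1)  # cosine similarity
tensor([0.6000])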
**Contrastive Learning:**
This function uses "in-batch" contrastive learning, meaning it creates negative examples by
pairing each positive example with other examples in the same batch that have the same relation
type but different source/destination nodes. This is computationally efficient since it reuses
embeddings already computed for the batch.
**Relation-aware:**
The "relation-aware" aspect means that negative examples are only created within the same
relation type. For example, if we have a positive example (person_A, "lives_in", city_B),
we only create negatives with other "lives_in" relations, not with "works_for" relations.
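As a small illustration (not a call to this function), a batch with condensed edge types
[0, 0, 1] admits the following in-batch negatives: same relation type only, with the positive
pair itself (the diagonal) excluded:
>>> rels = torch.tensor([0, 0, 1])
>>> rel_mask = rels[:, None] == rels[None, :]
>>> rel_mask & ~torch.eye(3, dtype=torch.bool)
tensor([[False,  True, False],
        [ True, False, False],
        [False, False, False]])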
Args:
src_embeddings (torch.Tensor): [B, D] Source node embeddings for each positive example.
B is the batch size, D is the embedding dimension. Each row represents the embedding
of a source entity in a knowledge graph triple.
dst_embeddings (torch.Tensor): [B, D] Destination node embeddings for each positive example.
Each row represents the embedding of a destination entity that should be similar to
its corresponding source entity.
condensed_edge_types (torch.Tensor): [B] Integer relation type IDs for each positive example.
This identifies which relation type connects each source-destination pair. Examples
within the same relation type can be used as negatives for each other.
temperature (float, optional): Temperature parameter for scaling similarities before softmax.
Lower values (< 1.0) make the model more confident in its predictions by sharpening
the probability distribution; higher values (> 1.0) make predictions more uniform
(see the illustration after this argument list). Defaults to 1.0.
scoring_function (SimilarityType, optional): Function used to compute similarity between
embeddings. Options are:
- COSINE: Cosine similarity (angle between vectors, normalized)
- DOT: Dot product (unnormalized, sensitive to magnitude)
- EUCLIDEAN: Negative squared Euclidean distance (closer = more similar)
Defaults to SimilarityType.COSINE.
corrupt_side (NegativeSamplingCorruptionType, optional): Which side of the triple to corrupt
when creating negative examples:
- DST: Replace destination nodes (e.g., (person_A, "lives_in", wrong_city))
- SRC: Replace source nodes (e.g., (wrong_person, "lives_in", city_B))
- BOTH: Randomly choose to corrupt either source or destination for each example
Defaults to NegativeSamplingCorruptionType.DST.
num_negatives (Optional[int], optional): Number of negative examples to sample per positive
example. If None, uses all valid negatives in the batch (can be computationally expensive
for large batches). Setting a specific number (e.g., 10) makes training more efficient.
Defaults to None.
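The effect of temperature can be seen with a quick, standalone softmax over three hypothetical
similarity scores (illustrative only; these tensors are not produced by this function):
>>> import torch.nn.functional as F
>>> sims = torch.tensor([2.0, 1.0, 0.5])
>>> F.softmax(sims / 0.5, dim=0)  # low temperature: sharper, more confident
tensor([0.8438, 0.1142, 0.0420])
>>> F.softmax(sims / 2.0, dim=0)  # high temperature: smoother, closer to uniform
tensor([0.4810, 0.2918, 0.2272])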
Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
logits (torch.Tensor): [B, 1 + K] Similarity scores for positive and negative pairs.
The first column contains similarities for the true positive pairs; the remaining
K columns contain similarities for the negative pairs, where K equals num_negatives
when it is set and B otherwise (invalid in-batch positions are filled with -inf).
Higher values indicate higher similarity.
labels (torch.Tensor): [B, 1 + K] Binary labels corresponding to the logits.
The first column is all 1s (indicating positive pairs), and the remaining
columns are all 0s (indicating negative pairs). Used for computing the contrastive loss.
Example:
>>> # Batch of 3 examples with 2D embeddings
>>> src_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
>>> dst_emb = torch.tensor([[1.0, 0.1], [0.1, 1.0], [1.1, 1.0]])
>>> relations = torch.tensor([0, 0, 1]) # First two are same relation type
>>>
>>> logits, labels = in_batch_relationwise_contrastive_similarity(
... src_emb, dst_emb, relations, num_negatives=1
... )
>>> # logits.shape: [3, 2] (positive + 1 negative per example)
>>> # labels.shape: [3, 2] with first column all 1s, second column all 0s
>>> # (the third example has no same-relation in-batch negative, so its sampled
>>> # negative logit is -inf)
"""
B, D = src_embeddings.shape
# Compute similarity matrix between src and dst embeddings
sim_fn = scoring_function.get_similarity_fn()
sim_matrix = sim_fn(src_embeddings, dst_embeddings)
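# sim_matrix[i, j] scores src_embeddings[i] against dst_embeddings[j]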
sim_matrix = sim_matrix / temperature # [B, B]
# Create mask indicating valid relations (same relation type)
rel_mask = condensed_edge_types[:, None] == condensed_edge_types[None, :] # [B, B]
# Identity matrix for diagonal masking (positive pair)
identity = torch.diag(torch.ones_like(condensed_edge_types, dtype=torch.bool))
# Process based on corruption side: "src", "dst", or "both"
if corrupt_side == NegativeSamplingCorruptionType.SRC:
# In "src" corruption, we modify the source side, keeping the destination side fixed
sim_matrix = sim_matrix.T # [B, B] -> Now rows are dst, columns are src
elif corrupt_side == NegativeSamplingCorruptionType.DST:
# In "dst" corruption, we modify the destination side, keeping the source side fixed
# No change to sim_matrix, this is the default behavior
pass
elif corrupt_side == NegativeSamplingCorruptionType.BOTH:
# In "both" corruption, we randomly decide for each row whether to corrupt the
# source or the destination (50% chance each)
is_src_corruption = (
torch.rand_like(condensed_edge_types, dtype=torch.float32) < 0.5
) # [B] True -> corrupt the source for this row
# Source-corruption view: transpose so each row i is scored against alternative sources
# (rel_mask is symmetric, so it needs no adjustment)
sim_matrix_src = sim_matrix.T # [B, B] -> rows are dst, columns are src
# Corrupt the destination (standard sim_matrix dst corruption)
sim_matrix_dst = sim_matrix
# Combine the two corruptions
sim_matrix = torch.where(
is_src_corruption.unsqueeze(1), sim_matrix_src, sim_matrix_dst
)
else:
raise ValueError(
f"Corruption type must be in {NegativeSamplingCorruptionType.SRC, NegativeSamplingCorruptionType.DST, NegativeSamplingCorruptionType.BOTH}; got {corrupt_side}."
)
# Mask invalid negatives (i.e., non-matching relations and diagonal)
neg_mask = rel_mask & ~identity # Mask for valid negative pairs
# Get positive logits (diagonal of the similarity matrix)
pos_logits = sim_matrix.diagonal().unsqueeze(1) # [B, 1]
# Mask the similarity matrix to only keep valid negatives
logits_masked = sim_matrix.masked_fill(~neg_mask, float("-inf")) # [B, B]
if num_negatives is None:
# If no negative sampling, use all valid negatives
logits = torch.cat([pos_logits, logits_masked], dim=1)
labels = torch.zeros_like(logits, dtype=torch.float)
labels = labels.scatter(
1, torch.zeros_like(condensed_edge_types, dtype=torch.long).view(-1, 1), 1
) # Set positive labels to 1 (first column)
return logits, labels
# ---- Fully tensorized negative sampling ----
# Generate random scores for sampling negative pairs
rand = torch.rand_like(logits_masked) # [B, B]
rand.masked_fill_(
~neg_mask, float("inf")
) # Set invalid positions to +inf so they won't be selected in topk
# Sample negatives using topk: smallest random scores are selected
K = num_negatives
sampled_idx = rand.topk(K, dim=1, largest=False, sorted=False).indices # [B, K]
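# Picking the K smallest i.i.d. uniform draws over valid positions is equivalent to sampling
# K negatives uniformly without replacement. Rows with fewer than K valid negatives still
# yield K columns; the surplus positions fall on masked entries and gather -inf logits below.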
# Gather negative logits based on the sampled indices
neg_logits = logits_masked.gather(1, sampled_idx) # [B, K] gather negative logits
# Concatenate positive logits with negative logits
logits = torch.cat([pos_logits, neg_logits], dim=1) # [B, 1 + K]
labels = torch.zeros_like(logits, dtype=torch.float)
labels = labels.scatter(
1, torch.zeros_like(condensed_edge_types, dtype=torch.long).view(-1, 1), 1
) # Set positive labels to 1 (first column)
return logits, labels
def against_batch_relationwise_contrastive_similarity(
positive_src_embeddings: torch.Tensor, # [B, D]
positive_dst_embeddings: torch.Tensor, # [B, D]
positive_condensed_edge_types: torch.Tensor, # [B]
negative_batch_src_embeddings: torch.Tensor, # [N, D]
negative_batch_dst_embeddings: torch.Tensor, # [N, D]
batch_condensed_edge_types: torch.Tensor, # [N]
temperature: float = 1.0,
scoring_function: SimilarityType = SimilarityType.COSINE,
corrupt_side: NegativeSamplingCorruptionType = NegativeSamplingCorruptionType.DST,
num_negatives: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Computes relation-aware contrastive similarity using an external batch of negative examples.
This function extends contrastive learning beyond the current batch by using a separate,
larger collection of potential negative examples. This approach often leads to better
embedding quality since it provides more diverse and challenging negative examples.
**Key Difference from In-Batch Sampling:**
While `in_batch_relationwise_contrastive_similarity` creates negatives by reshuffling
examples within the current batch, this function uses a completely separate set of
pre-computed embeddings as negative candidates. This allows for:
- More diverse negative examples
- Better control over negative example quality
- Ability to use hard negatives (examples that are difficult to distinguish from positives)
- Larger pools of negatives without increasing batch size
**Use Cases:**
- When you have a large corpus of pre-computed embeddings to use as negatives
- When implementing sophisticated negative sampling strategies (e.g., hard negatives)
- When memory constraints limit your batch size but you want many negative examples
- When training with cached embeddings from previous epochs
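As a small, purely illustrative sketch of the relation-aware masking (mirroring the Example
further below, not a call to this function), positives with edge types [0, 1] checked against
an external batch with edge types [0, 0, 1, 1, 2] admit only same-relation candidates:
>>> pos_rels = torch.tensor([0, 1])
>>> neg_rels = torch.tensor([0, 0, 1, 1, 2])
>>> pos_rels[:, None] == neg_rels[None, :]
tensor([[ True,  True, False, False, False],
        [False, False,  True,  True, False]])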
Args:
positive_src_embeddings (torch.Tensor): [B, D] Source node embeddings for the positive examples.
These are the "query" embeddings that we want to find good matches for. B is the
number of positive examples, D is the embedding dimension.
positive_dst_embeddings (torch.Tensor): [B, D] Destination node embeddings for the positive examples.
These are the true "target" embeddings that should be similar to their corresponding
source embeddings. Each row pairs with the corresponding row in positive_src_embeddings.
positive_condensed_edge_types (torch.Tensor): [B] Relation type IDs for the positive examples.
Integer identifiers specifying which relation type connects each source-destination
pair. This ensures that negative examples are only selected from the same relation type.
negative_batch_src_embeddings (torch.Tensor): [N, D] Source node embeddings from an external batch
that will serve as potential negative examples. N is typically much larger than B,
providing a rich pool of negative candidates. These embeddings might come from:
- A different batch from the same dataset
- Pre-computed embeddings from earlier training steps
- A carefully curated set of hard negative examples
negative_batch_dst_embeddings (torch.Tensor): [N, D] Destination node embeddings from the external
batch. These correspond to the source embeddings and will be used when corrupting
the destination side of triples.
batch_condensed_edge_types (torch.Tensor): [N] Relation type IDs for the external (negative)
batch. Used to ensure that negative examples maintain relation-type consistency. Only
embeddings with matching relation types will be considered as valid negatives.
temperature (float, optional): Temperature parameter for similarity scaling. Controls
the "sharpness" of the resulting probability distribution:
- Lower values (< 1.0): Make the model more confident, sharper distinctions
- Higher values (> 1.0): Make the model less confident, smoother distributions
- 1.0: No scaling applied
Defaults to 1.0.
scoring_function (SimilarityType, optional): Method for computing embedding similarity:
- COSINE: Normalized dot product, measures angle between vectors (most common)
- DOT: Raw dot product, sensitive to vector magnitudes
- EUCLIDEAN: Negative squared L2 distance, measures geometric distance
The choice affects how the model learns to represent relationships.
Defaults to SimilarityType.COSINE.
corrupt_side (NegativeSamplingCorruptionType, optional): Specifies which part of the
knowledge graph triple to replace when creating negative examples:
- DST: Replace destination nodes (e.g., (Albert_Einstein, "born_in", wrong_city))
- SRC: Replace source nodes (e.g., (wrong_person, "born_in", Germany))
- BOTH: Randomly choose to replace either source or destination for each example
Different corruption strategies can lead to different learned representations.
Defaults to NegativeSamplingCorruptionType.DST.
num_negatives (Optional[int], optional): Maximum number of negative examples to use per
positive example. Controls the computational/memory trade-off:
- None: Use all valid negatives from the external batch (can be expensive)
- Small number (5-20): Fast training, fewer negatives
- Large number (100+): Slower but potentially better quality learning
Defaults to None.
Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
logits (torch.Tensor): [B, 1 + K] Similarity scores matrix where:
- First column: Similarities between true positive pairs (src_i, dst_i)
- Remaining K columns: Similarities between each positive and its K negative examples
(K equals num_negatives when it is set, otherwise N; candidates with a non-matching
relation type are filled with -inf)
- Higher values indicate higher predicted similarity
- Used as input to contrastive loss functions
labels (torch.Tensor): [B, 1 + K] Binary label matrix corresponding to logits:
- First column: All 1s (indicating true positive pairs)
- Remaining K columns: All 0s (indicating negative pairs)
- Used as targets for contrastive loss computation
- Shape matches logits for element-wise loss calculation
Example:
>>> # 2 positive examples, 5 external candidates for negatives
>>> pos_src = torch.randn(2, 128) # 2 positive source embeddings
>>> pos_dst = torch.randn(2, 128) # 2 positive destination embeddings
>>> pos_rels = torch.tensor([0, 1]) # Different relation types
>>>
>>> neg_src = torch.randn(5, 128) # 5 potential negative sources
>>> neg_dst = torch.randn(5, 128) # 5 potential negative destinations
>>> neg_rels = torch.tensor([0, 0, 1, 1, 2]) # Mixed relation types
>>>
>>> logits, labels = against_batch_relationwise_contrastive_similarity(
... pos_src, pos_dst, pos_rels,
... neg_src, neg_dst, neg_rels,
... num_negatives=2 # Sample 2 negatives per positive
... )
>>> # logits.shape: [2, 3] (1 positive + 2 negatives per example)
>>> # labels.shape: [2, 3] (first column 1s, others 0s)
>>> # Only relation-type-matching negatives are selected
Note:
This function is particularly useful in advanced training scenarios where you want
fine-grained control over negative sampling, such as curriculum learning, hard
negative mining, or when working with very large knowledge graphs where in-batch
sampling provides insufficient diversity.
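One common way to consume the returned pair, shown here only as a sketch (not necessarily
how the surrounding library computes its loss), is a softmax cross-entropy with the positive
in column 0; -inf entries receive zero probability and are effectively ignored:
>>> # loss = torch.nn.functional.cross_entropy(logits, labels.argmax(dim=1))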
"""
B, D = positive_src_embeddings.shape
N = negative_batch_src_embeddings.shape[0]
# Similarity function used for the positive pairs and the [B, N] negative candidates
sim_fn = scoring_function.get_similarity_fn()
# Build masks for valid negatives per relation type
# [B, N]: True where edge types match
rel_mask = (
positive_condensed_edge_types[:, None] == batch_condensed_edge_types[None, :]
)
# Positive similarity, temperature-scaled to match the negative logits below
pos_logits = (
sim_fn(positive_src_embeddings, positive_dst_embeddings).diagonal().unsqueeze(1)
/ temperature
) # [B, 1]
# Negative similarity matrix (B x N), depends on corruption side
if corrupt_side == NegativeSamplingCorruptionType.SRC:
neg_sim_matrix = sim_fn(
negative_batch_src_embeddings, positive_dst_embeddings
) # [N, B]
neg_sim_matrix = neg_sim_matrix.T # [B, N]
elif corrupt_side == NegativeSamplingCorruptionType.DST:
neg_sim_matrix = sim_fn(
positive_src_embeddings, negative_batch_dst_embeddings
) # [B, N]
elif corrupt_side == NegativeSamplingCorruptionType.BOTH:
is_src_corruption = (
torch.rand_like(positive_condensed_edge_types, dtype=torch.float32) < 0.5
) # [B] True with probability 0.5 -> corrupt the source for this row
neg_sim_src = sim_fn(
negative_batch_src_embeddings, positive_dst_embeddings
).T # [B, N]
neg_sim_dst = sim_fn(
positive_src_embeddings, negative_batch_dst_embeddings
) # [B, N]
neg_sim_matrix = torch.where(
is_src_corruption.unsqueeze(1), neg_sim_src, neg_sim_dst
)
else:
raise ValueError(f"Invalid corrupt_side: {corrupt_side}")
neg_sim_matrix = neg_sim_matrix / temperature # [B, N]
# Mask invalid negatives (i.e., non-matching relations)
logits_masked = neg_sim_matrix.masked_fill(~rel_mask, float("-inf")) # [B, N]
if num_negatives is None:
# If no negative sampling, use all valid negatives
logits = torch.cat([pos_logits, logits_masked], dim=1)
labels = torch.zeros_like(logits, dtype=torch.float)
labels = labels.scatter(
1,
torch.zeros_like(positive_condensed_edge_types, dtype=torch.long).view(
-1, 1
),
1,
) # Set positive labels to 1 (first column)
return logits, labels
# Sample K negatives per row from matching relation types
rand = torch.rand_like(logits_masked) # [B, N]
rand.masked_fill_(~rel_mask, float("inf")) # Prevent selecting mismatched negatives
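# As in the in-batch variant, rows with fewer than num_negatives matching candidates will
# sample masked positions whose gathered logits are -inf.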
sampled_idx = rand.topk(num_negatives, dim=1, largest=False).indices # [B, K]
# Gather negative similarities
neg_logits = logits_masked.gather(1, sampled_idx) # [B, K] gather negative logits
# Concatenate positive logits with negative logits
logits = torch.cat([pos_logits, neg_logits], dim=1) # [B, 1 + K]
labels = torch.zeros_like(logits, dtype=torch.float)
labels = labels.scatter(
1,
torch.zeros_like(positive_condensed_edge_types, dtype=torch.long).view(-1, 1),
1,
) # Set positive labels to 1 (first column)
return logits, labels