Source code for gigl.experimental.knowledge_graph_embedding.lib.model.heterogeneous_graph_model

from typing import Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchrec

from gigl.common.logger import Logger
from gigl.experimental.knowledge_graph_embedding.common.torchrec.large_embedding_lookup import (
    LargeEmbeddingLookup,
)
from gigl.experimental.knowledge_graph_embedding.lib.config import ModelConfig
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
    SamplingConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.data.edge_batch import EdgeBatch
from gigl.experimental.knowledge_graph_embedding.lib.data.node_batch import NodeBatch
from gigl.experimental.knowledge_graph_embedding.lib.model.negative_sampling import (
    against_batch_relationwise_contrastive_similarity,
    in_batch_relationwise_contrastive_similarity,
)
from gigl.experimental.knowledge_graph_embedding.lib.model.types import (
    ModelPhase,
    SimilarityType,
)

# Module-level logger used by the model classes below for info/debug messages.
logger = Logger()
# TODO(nshah): This could be refactored to be more modular and have individualized APIs for individual KGE model variants.
class HeterogeneousGraphSparseEmbeddingModel(nn.Module):
    """
    A backbone model to support sparse embedding of (possibly multi-relational) graphs.
    Can also be used to implement matrix factorization and variants.

    Useful overviews on Knowledge Graph Embedding:
    - Knowledge Graph Embedding: An Overview (Ge et al, 2023): https://arxiv.org/pdf/2309.12501
    - Stanford CS224W: ML with Graphs: Knowledge Graph Embeddings (2023): https://www.youtube.com/watch?v=isI_TUMoP60

    Args:
        model_config (ModelConfig): Configuration object containing model parameters.
    """

    def __init__(self, model_config: ModelConfig):
        """Initialize the model with the given embedding configurations."""
        super().__init__()
        self.num_edge_types = model_config.num_edge_types
        self.node_emb_dim = model_config.node_embedding_dim
        self.training_sampling_config = model_config.training_sampling
        self.validation_sampling_config = model_config.validation_sampling
        self.testing_sampling_config = model_config.testing_sampling
        self.similarity_type = model_config.embedding_similarity_type
        # The model starts in the TRAIN phase; use `set_phase` to switch.
        self._phase: ModelPhase = ModelPhase.TRAIN
        self._assert_sampling_config_is_valid()

        # Define the embedding layers.
        assert (
            model_config.embeddings_config is not None
        ), "Embedding config must be provided."
        self.large_embeddings = LargeEmbeddingLookup(
            embeddings_config=model_config.embeddings_config
        )

        # Define the operators applied to src and dst node types respectively
        assert self.num_edge_types is not None, "Number of edge types must be provided."
        self.src_operator = model_config.src_operator.get_corresponding_module()(
            num_edge_types=self.num_edge_types,
            node_emb_dim=self.node_emb_dim,
        )
        self.dst_operator = model_config.dst_operator.get_corresponding_module()(
            num_edge_types=self.num_edge_types,
            node_emb_dim=self.node_emb_dim,
        )

        logger.info(f"Initialized model with: {self.__dict__}")

    def _assert_sampling_config_is_valid(self):
        """
        Check the training, validation and testing sampling configs.

        Each config must be present, have a positive `positive_edge_batch_size`,
        and request at least one in-batch or random negative per edge.

        Raises:
            AssertionError: If any of the three configs violates these conditions.
        """
        for sampling_config in (
            self.training_sampling_config,
            self.validation_sampling_config,
            self.testing_sampling_config,
        ):
            assert sampling_config is not None, "Sampling config must be provided."
            assert (
                sampling_config.positive_edge_batch_size > 0
            ), "Positive edge batch size must be greater than 0."
            assert (
                sampling_config.num_inbatch_negatives_per_edge
                + sampling_config.num_random_negatives_per_edge
                > 0
            ), "At least one type of negative sampling must be specified."

    @property
    def active_sampling_config(self) -> SamplingConfig:
        """
        The sampling config that matches the current phase.

        Returns:
            SamplingConfig: The training, validation or testing sampling config.

        Raises:
            ValueError: If the model is in an inference phase (no sampling config
                applies) or in an unrecognized phase.
        """
        if self.phase == ModelPhase.TRAIN:
            return self.training_sampling_config  # type: ignore
        elif self.phase == ModelPhase.VAL:
            return self.validation_sampling_config  # type: ignore
        elif self.phase == ModelPhase.TEST:
            return self.testing_sampling_config  # type: ignore
        elif (
            self.phase == ModelPhase.INFERENCE_SRC
            or self.phase == ModelPhase.INFERENCE_DST
        ):
            raise ValueError(
                "Active sampling config is not defined for inference phase. "
            )
        else:
            raise ValueError(
                f"Unknown model phase: {self.phase}. Cannot determine active sampling config."
            )

    def set_phase(self, phase: ModelPhase):
        """
        Set the phase of the model. This is used to determine which sampling
        configuration to use during training, validation, or testing.

        Note that this affects
        (i) how data that is passed into the model is interpreted (e.g. #s of positives, negatives)
        (ii) whether inbatch negatives are used to compute logits and labels

        Args:
            phase (ModelPhase): The current phase of the model
                (TRAIN, VAL, TEST, INFERENCE_SRC or INFERENCE_DST).
        """
        old_phase = self._phase
        self._phase = phase
        logger.info(f"Changed model phase from {old_phase} to {phase}")

    @property
    def phase(self) -> ModelPhase:
        """The phase the model is currently in."""
        return self._phase

    def fetch_src_and_dst_embeddings(
        self, edge_batch: EdgeBatch
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Look up the embeddings of both endpoints of every edge in the batch.

        The lookup result is reshaped so that the per-node-type slices are
        summed into a single vector per node, and then split into the source
        and destination halves of each edge.

        Args:
            edge_batch (EdgeBatch): The batch of edges to fetch embeddings for.

        Returns:
            Tuple of (src_embeddings, dst_embeddings), each [num_edges, node_dim].
        """
        # NOTE(review): the debug f-strings below are formatted eagerly, so the
        # tensors are rendered to text on every call even when debug logging is
        # off. Consider guarding them if `Logger` exposes a level check.
        num_edges = edge_batch.batch_size
        node_embeddings_kt: torchrec.KeyedTensor = self.large_embeddings(
            edge_batch.src_dst_pairs
        )
        logger.debug(f"node embeddings kt: {node_embeddings_kt}")

        node_embeddings = (
            node_embeddings_kt.values()
        )  # [2 * num_edges, num_node_types * node_dim]
        logger.debug(
            f"node embeddings kt values: {node_embeddings, node_embeddings.shape}"
        )

        node_embeddings = node_embeddings.reshape(
            2 * num_edges, -1, self.node_emb_dim
        )  # [2 * num_edges, num_node_types, node_dim]
        logger.debug(
            f"node embeddings values reshaped: {node_embeddings, node_embeddings.shape}"
        )

        node_embeddings = node_embeddings.sum(dim=1)  # [2 * num_edges, node_dim]
        logger.debug(
            f"node embeddings collapse middle axis: {node_embeddings, node_embeddings.shape}"
        )

        node_embeddings = node_embeddings.reshape(
            num_edges, 2, self.node_emb_dim
        )  # [num_edges, 2, node_dim]
        logger.debug(
            f"node embeddings reshape into correct tensor: {node_embeddings, node_embeddings.shape}"
        )

        src_embeddings = node_embeddings[:, 0, :]
        dst_embeddings = node_embeddings[:, 1, :]
        return src_embeddings, dst_embeddings

    def apply_relation_operator(
        self,
        src_embeddings: torch.Tensor,
        dst_embeddings: torch.Tensor,
        condensed_edge_types: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the src and dst relation operators to the source and destination embeddings.

        Some reasonable configurations to reimplement common KG embedding algorithms:
        - TransE: Translation on src embeddings, dst embeddings remain unchanged
        - CompleX: Complex diagonal on src embeddings, dst embeddings remain unchanged
        - DistMult: Diagonal on src embeddings, dst embeddings remain unchanged
        - RESCAL: Linear on src embeddings, dst embeddings remain unchanged

        This can also be used to implement things like raw Matrix Factorization
        by using identity operators, or other custom operators.

        Args:
            src_embeddings: Source node embeddings.
            dst_embeddings: Destination node embeddings.
            condensed_edge_types: Edge types for the current batch.

        Returns:
            Tuple of transformed source and destination embeddings.
        """
        # Apply the src operator to the source embeddings
        src_embeddings = self.src_operator(
            embeddings=src_embeddings,
            condensed_edge_types=condensed_edge_types,
        )

        # Apply the dst operator to the destination embeddings
        dst_embeddings = self.dst_operator(
            embeddings=dst_embeddings,
            condensed_edge_types=condensed_edge_types,
        )

        return src_embeddings, dst_embeddings

    def score_edges(
        self,
        src_embeddings: torch.Tensor,
        dst_embeddings: torch.Tensor,
    ):
        """
        Score (src, dst) embedding pairs row by row.

        Args:
            src_embeddings: Source node embeddings.
            dst_embeddings: Destination node embeddings.

        Returns:
            One score per row: the dot product or the cosine similarity along
            dim 1, depending on `self.similarity_type`.

        Raises:
            ValueError: If `self.similarity_type` is neither DOT nor COSINE.
        """
        # Compute the scores using the specified scoring function
        if self.similarity_type == SimilarityType.DOT:
            scores = torch.sum(src_embeddings * dst_embeddings, dim=1)
        elif self.similarity_type == SimilarityType.COSINE:
            scores = F.cosine_similarity(src_embeddings, dst_embeddings, dim=1)
        else:
            raise ValueError(f"Unknown scoring function: {self.similarity_type}")

        return scores

    def infer_node_batch(
        self,
        node_batch: NodeBatch,
    ) -> torch.Tensor:
        """
        Infer node embeddings for a given NodeBatch.

        Args:
            node_batch (NodeBatch): The batch of nodes to infer embeddings for.

        Returns:
            torch.Tensor: The inferred node embeddings.
        """
        # Fetch node embeddings from the embedding layer
        num_nodes = node_batch.nodes.values().numel()
        node_embeddings_kt: torchrec.KeyedTensor = self.large_embeddings(
            node_batch.nodes
        )
        node_embeddings = node_embeddings_kt.values()
        node_embeddings = node_embeddings.reshape(
            num_nodes, -1, self.node_emb_dim
        )  # [num_nodes, num_node_types, node_dim]
        node_embeddings = node_embeddings.sum(dim=1)  # [num_nodes, node_dim]

        # The src operator is used only in INFERENCE_SRC; every other phase
        # falls through to the dst operator.
        operator = (
            self.src_operator
            if self.phase == ModelPhase.INFERENCE_SRC
            else self.dst_operator
        )

        # Apply the operator to the node embeddings
        node_embeddings = operator(
            embeddings=node_embeddings,
            condensed_edge_types=node_batch.condensed_edge_type.repeat(num_nodes),
        )

        return node_embeddings

    def forward(
        self, edge_batch: EdgeBatch
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute contrastive logits and labels for a batch of edges.

        The first `positive_edge_batch_size` edges of the batch are treated as
        positives; the remaining edges are passed to the against-batch helper as
        the pool of negatives.

        Args:
            edge_batch (EdgeBatch): The batch of edges to score.

        Returns:
            Tuple of (logits, labels, pos_condensed_edge_types). In `logits` and
            `labels`, column 0 is the positive edge; the following columns are
            the in-batch negatives (if enabled) and then the against-batch
            negatives (if enabled).

        Raises:
            ValueError: If the model is in an inference phase, for which no
                sampling config is defined.
        """
        # Fetch node embeddings from the embedding layer
        src_embeddings, dst_embeddings = self.fetch_src_and_dst_embeddings(edge_batch)
        condensed_edge_types = edge_batch.condensed_edge_types

        # Apply relation operator to transform embeddings based on edge types
        src_embeddings, dst_embeddings = self.apply_relation_operator(
            src_embeddings=src_embeddings,
            dst_embeddings=dst_embeddings,
            condensed_edge_types=condensed_edge_types,
        )

        # Resolve the phase-dependent config once instead of re-evaluating the
        # property for every slice and keyword argument below.
        sampling_config = self.active_sampling_config
        num_positives = sampling_config.positive_edge_batch_size

        pos_src_embeddings = src_embeddings[:num_positives]
        batch_neg_src_embeddings = src_embeddings[num_positives:]
        pos_dst_embeddings = dst_embeddings[:num_positives]
        batch_neg_dst_embeddings = dst_embeddings[num_positives:]
        pos_condensed_edge_types = condensed_edge_types[:num_positives]
        batch_neg_condensed_edge_types = condensed_edge_types[num_positives:]

        # Placeholders only: the constructor validates that every sampling
        # config enables at least one kind of negative, so one of the branches
        # below always overwrites them.
        pos_logits: torch.Tensor = torch.tensor([])
        pos_labels: torch.Tensor = torch.tensor([])
        neg_logits: list[torch.Tensor] = []
        neg_labels: list[torch.Tensor] = []

        if sampling_config.num_inbatch_negatives_per_edge:
            # Do inbatch negative sampling and compute logits and labels
            (
                in_batch_logits,
                in_batch_labels,
            ) = in_batch_relationwise_contrastive_similarity(
                src_embeddings=pos_src_embeddings,
                dst_embeddings=pos_dst_embeddings,
                condensed_edge_types=pos_condensed_edge_types,
                scoring_function=self.similarity_type,
                corrupt_side=sampling_config.negative_corruption_side,
                num_negatives=sampling_config.num_inbatch_negatives_per_edge,
            )

            pos_logits = in_batch_logits[:, 0].unsqueeze(1)
            pos_labels = in_batch_labels[:, 0].unsqueeze(1)
            neg_logits.append(in_batch_logits[:, 1:])
            neg_labels.append(in_batch_labels[:, 1:])

        if sampling_config.num_random_negatives_per_edge:
            (
                against_batch_logits,
                against_batch_labels,
            ) = against_batch_relationwise_contrastive_similarity(
                positive_src_embeddings=pos_src_embeddings,
                positive_dst_embeddings=pos_dst_embeddings,
                positive_condensed_edge_types=pos_condensed_edge_types,
                negative_batch_src_embeddings=batch_neg_src_embeddings,
                negative_batch_dst_embeddings=batch_neg_dst_embeddings,
                batch_condensed_edge_types=batch_neg_condensed_edge_types,
                scoring_function=self.similarity_type,
                corrupt_side=sampling_config.negative_corruption_side,
                num_negatives=sampling_config.num_random_negatives_per_edge,
            )

            # These pos_logits and pos_labels are the same as those from in-batch similarity calculations.
            # We keep them here for consistency.
            pos_logits = against_batch_logits[:, 0].unsqueeze(1)
            pos_labels = against_batch_labels[:, 0].unsqueeze(1)
            neg_logits.append(against_batch_logits[:, 1:])
            neg_labels.append(against_batch_labels[:, 1:])

        # Concatenate positive and negative samples
        all_neg_logits = torch.cat(neg_logits, dim=1)
        all_neg_labels = torch.cat(neg_labels, dim=1)

        logits = torch.cat([pos_logits, all_neg_logits], dim=1)
        labels = torch.cat([pos_labels, all_neg_labels], dim=1)

        return logits, labels, pos_condensed_edge_types
class HeterogeneousGraphSparseEmbeddingModelAndLoss(nn.Module):
    """
    Couples a `HeterogeneousGraphSparseEmbeddingModel` with a loss function.

    The wrapper exists so the encoder can be driven by the torchrec
    TrainPipeline abstraction, which by convention expects the module's
    forward pass to return the loss together with its outputs. For more
    details, see:
    - https://github.com/pytorch/torchrec/blob/3ec6f537bf230556b58f5a527ed32e23cc50849d/examples/golden_training/train_dlrm.py#L111
    - https://github.com/pytorch/torchrec/blob/3ec6f537bf230556b58f5a527ed32e23cc50849d/torchrec/models/dlrm.py#L850
    """

    def __init__(
        self,
        encoder_model: HeterogeneousGraphSparseEmbeddingModel,
        loss_fn: Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor
        ] = F.binary_cross_entropy_with_logits,
    ):
        """
        Store the encoder and the loss function.

        Args:
            encoder_model (HeterogeneousGraphSparseEmbeddingModel): The wrapped
                model that produces logits/labels or node embeddings.
            loss_fn (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): Called
                as `loss_fn(logits, labels)`. Defaults to binary cross-entropy
                with logits.
        """
        super().__init__()
        self.encoder_model = encoder_model
        self.loss_fn = loss_fn

    def set_phase(self, phase: ModelPhase):
        """
        Forward the phase change to the wrapped encoder model.

        The phase decides which sampling configuration the encoder uses and
        whether `forward` expects an EdgeBatch or a NodeBatch.

        Args:
            phase (ModelPhase): The phase to switch the encoder to.
        """
        self.encoder_model.set_phase(phase=phase)

    @property
    def phase(self) -> ModelPhase:
        """The phase of the wrapped encoder model."""
        return self.encoder_model.phase

    def forward(self, batch: Union[EdgeBatch, NodeBatch]) -> tuple[torch.Tensor, tuple]:
        """
        Run either node inference or the loss computation, depending on the phase.

        Args:
            batch (Union[EdgeBatch, NodeBatch]): A NodeBatch in the INFERENCE_SRC /
                INFERENCE_DST phases, an EdgeBatch in every other phase.

        Returns:
            Tuple[torch.Tensor, Tuple]:
                - Inference phases: `(dummy_loss, (node_ids, node_embeddings))`,
                  where `dummy_loss` is a constant zero tensor.
                - Other phases: `(loss, (loss, logits, labels, condensed_edge_types))`.
                All tensors inside the inner tuple are detached.

        Raises:
            AssertionError: If the batch type does not match the current phase.
        """
        current_phase = self.phase
        in_inference = (
            current_phase == ModelPhase.INFERENCE_SRC
            or current_phase == ModelPhase.INFERENCE_DST
        )

        if not in_inference:
            # We expect an EdgeBatch in training, validation, or testing phases.
            assert isinstance(
                batch, EdgeBatch
            ), "Batch must be an EdgeBatch in training, validation, or testing phases."

            logits: torch.Tensor
            labels: torch.Tensor
            logits, labels, condensed_edge_types = self.encoder_model(edge_batch=batch)
            loss = self.loss_fn(logits, labels)

            detached_outputs = (
                loss.detach(),
                logits.detach(),
                labels.detach(),
                condensed_edge_types.detach(),
            )
            return loss, detached_outputs

        # In inference phase, we only handle NodeBatch.
        assert isinstance(
            batch, NodeBatch
        ), "Batch must be a NodeBatch in inference phase."

        embeddings = self.encoder_model.infer_node_batch(node_batch=batch)
        # No loss exists at inference time; a zero tensor fills the loss slot.
        placeholder_loss = torch.tensor(0.0)
        ids = batch.nodes.values()
        return placeholder_loss, (ids.detach(), embeddings.detach())