gigl.experimental.knowledge_graph_embedding.lib.model.heterogeneous_graph_model#
Classes#
| Class | Description |
| --- | --- |
| HeterogeneousGraphSparseEmbeddingModel | A backbone model to support sparse embedding of (possibly multi-relational) graphs. |
| HeterogeneousGraphSparseEmbeddingModelAndLoss | A simple heterogeneous information network model with loss. |
Module Contents#
- class gigl.experimental.knowledge_graph_embedding.lib.model.heterogeneous_graph_model.HeterogeneousGraphSparseEmbeddingModel(model_config)[source]#

  Bases: torch.nn.Module

  A backbone model to support sparse embedding of (possibly multi-relational) graphs. Can also be used to implement matrix factorization and variants.

  Useful overviews on Knowledge Graph Embedding:

  - Knowledge Graph Embedding: An Overview (Ge et al., 2023): https://arxiv.org/pdf/2309.12501
  - Stanford CS224W: ML with Graphs: Knowledge Graph Embeddings (2023): https://www.youtube.com/watch?v=isI_TUMoP60

  Parameters:

  - model_config (ModelConfig) – Configuration object containing model parameters.

  Initialize the model with the given embedding configurations.
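  A minimal construction sketch, assuming a fully populated ModelConfig is already available (its fields are defined elsewhere in the library and not shown here):

```python
from gigl.experimental.knowledge_graph_embedding.lib.model.heterogeneous_graph_model import (
    HeterogeneousGraphSparseEmbeddingModel,
)

# `model_config` is assumed to be a fully populated ModelConfig instance;
# its construction is library-specific and omitted here.
model = HeterogeneousGraphSparseEmbeddingModel(model_config=model_config)
```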
- apply_relation_operator(src_embeddings, dst_embeddings, condensed_edge_types)[source]#

  Apply the src and dst relation operators to the source and destination embeddings.

  Some reasonable configurations to reimplement common KG embedding algorithms (see the sketch after this list):

  - TransE: Translation on src embeddings; dst embeddings remain unchanged.
  - ComplEx: Complex diagonal on src embeddings; dst embeddings remain unchanged.
  - DistMult: Diagonal on src embeddings; dst embeddings remain unchanged.
  - RESCAL: Linear on src embeddings; dst embeddings remain unchanged.

  This can also be used to implement approaches such as raw matrix factorization (by using identity operators) or other custom operators.

  Parameters:

  - src_embeddings (torch.Tensor) – Source node embeddings.
  - dst_embeddings (torch.Tensor) – Destination node embeddings.
  - condensed_edge_types (torch.Tensor) – Edge types for the current batch.

  Returns:

  Tuple of transformed source and destination embeddings.

  Return type:

  tuple[torch.Tensor, torch.Tensor]
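  For intuition, here is a minimal, self-contained sketch of the TransE-style configuration above, in which each relation owns a learned translation vector added to the source embedding. The class and attribute names are illustrative and are not the library's actual operator implementation:

```python
import torch
import torch.nn as nn


class TranslationRelationOperator(nn.Module):
    """Illustrative TransE-style relation operator (not the library's own):
    translate src embeddings by a per-relation vector; pass dst through."""

    def __init__(self, num_relations: int, embedding_dim: int):
        super().__init__()
        self.translations = nn.Embedding(num_relations, embedding_dim)

    def forward(
        self,
        src_embeddings: torch.Tensor,  # [batch, dim]
        dst_embeddings: torch.Tensor,  # [batch, dim]
        condensed_edge_types: torch.Tensor,  # [batch], long indices
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # TransE intuition: h + r ≈ t, so only the source side is transformed.
        translated_src = src_embeddings + self.translations(condensed_edge_types)
        return translated_src, dst_embeddings
```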
 
- fetch_src_and_dst_embeddings(edge_batch)[source]#

  Parameters:

  - edge_batch (gigl.experimental.knowledge_graph_embedding.lib.data.edge_batch.EdgeBatch)

  Return type:

  tuple[torch.Tensor, torch.Tensor]
 
- forward(edge_batch)[source]#

  Parameters:

  - edge_batch (gigl.experimental.knowledge_graph_embedding.lib.data.edge_batch.EdgeBatch)

  Return type:

  tuple[torch.Tensor, torch.Tensor, torch.Tensor]
 
- infer_node_batch(node_batch)[source]#

  Infer node embeddings for a given NodeBatch.

  Parameters:

  - node_batch (NodeBatch) – The batch of nodes to infer embeddings for.

  Returns:

  The inferred node embeddings.

  Return type:

  torch.Tensor
 
- score_edges(src_embeddings, dst_embeddings)[source]#

  Parameters:

  - src_embeddings (torch.Tensor)
  - dst_embeddings (torch.Tensor)
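  Taken together, a plausible flow through fetch_src_and_dst_embeddings, apply_relation_operator, and score_edges for an edge batch looks roughly like the following. This is an illustrative sketch: the edge_batch.condensed_edge_types attribute and the exact call sequence inside forward are assumptions, not confirmed API:

```python
# Illustrative composition of the methods above; not the actual forward().
src_emb, dst_emb = model.fetch_src_and_dst_embeddings(edge_batch)
src_emb, dst_emb = model.apply_relation_operator(
    src_emb,
    dst_emb,
    condensed_edge_types=edge_batch.condensed_edge_types,  # assumed attribute
)
scores = model.score_edges(src_emb, dst_emb)
```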
 
 
- set_phase(phase)[source]#

  Set the phase of the model. This is used to determine which sampling configuration to use during training, validation, or testing.

  Note that this affects (i) how data passed into the model is interpreted (e.g., the numbers of positives and negatives), and (ii) whether in-batch negatives are used to compute logits and labels.

  Parameters:

  - phase (ModelPhase) – The current phase of the model (TRAIN, VALIDATION, TEST).
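  A short usage sketch; ModelPhase is assumed to be importable from the library (import path not shown), with member names following the docstring above:

```python
model.set_phase(ModelPhase.TRAIN)
# ... run training steps ...

model.set_phase(ModelPhase.VALIDATION)
# ... run validation steps; the active sampling configuration
# switches to the validation settings ...
```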
 
 - property active_sampling_config: gigl.experimental.knowledge_graph_embedding.lib.config.sampling.SamplingConfig[source]#
 
- class gigl.experimental.knowledge_graph_embedding.lib.model.heterogeneous_graph_model.HeterogeneousGraphSparseEmbeddingModelAndLoss(encoder_model, loss_fn=F.binary_cross_entropy_with_logits)[source]#

  Bases: torch.nn.Module

  A simple heterogeneous information network model with loss. This module wraps the HeterogeneousGraphSparseEmbeddingModel for use with the torchrec TrainPipeline abstraction, which imposes specific input/output expectations regarding loss and outputs. For more details, see the torchrec TrainPipeline documentation.

  Initialize the model with the given encoder model and loss function.

  Parameters:

  - encoder_model (HeterogeneousGraphSparseEmbeddingModel) – The underlying model for encoding.
  - loss_fn (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) – The loss function to compute the loss. Defaults to binary cross-entropy with logits.
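  A minimal wrapping sketch using the default loss; the names mirror the signature above, and model_config is assumed to exist:

```python
import torch.nn.functional as F

encoder = HeterogeneousGraphSparseEmbeddingModel(model_config=model_config)
model_and_loss = HeterogeneousGraphSparseEmbeddingModelAndLoss(
    encoder_model=encoder,
    loss_fn=F.binary_cross_entropy_with_logits,  # the documented default
)
```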
 
- forward(batch)[source]#

  If the batch is an EdgeBatch, compute the loss and return it along with the logits and labels. If the batch is a NodeBatch, infer node embeddings instead.
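  An illustrative dispatch sketch; the exact packing of loss, logits, and labels in the return value follows TorchRec's pipeline conventions and is not spelled out here, so the output structures below are assumptions:

```python
# EdgeBatch input: the wrapper computes the loss and returns it along
# with logits and labels (exact return structure is an assumption).
edge_outputs = model_and_loss(edge_batch)

# NodeBatch input: node embeddings are inferred instead.
node_embeddings = model_and_loss(node_batch)
```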
- set_phase(phase)[source]#

  Set the phase of the encoder model. This is used to determine which sampling configuration to use during training, validation, or testing.

  Note that this affects (i) how data passed into the model is interpreted (e.g., the numbers of positives and negatives), and (ii) whether in-batch negatives are used to compute logits and labels.

  Parameters:

  - phase (ModelPhase) – The current phase of the model (TRAIN, VAL, TEST, INFERENCE).
 
 
