Source code for gigl.experimental.knowledge_graph_embedding.lib.model.operators

import torch
import torch.nn as nn


class RelationwiseOperatorBase(nn.Module):
    """
    Common parent of every relation-conditioned embedding operator.

    A relationwise operator takes node embeddings together with the edge
    type (relation) each embedding is used with, and returns embeddings
    transformed according to that relation. Concrete operators override
    :meth:`forward`; this base class holds no parameters of its own.

    Args:
        num_edge_types: Number of distinct (condensed) edge types. Not used
            by the base class itself; accepted so all operators share one
            constructor signature.
        node_emb_dim: Dimensionality of the node embeddings. Not used by the
            base class itself.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        # Only the nn.Module machinery is set up here; subclasses create
        # their own relation-specific parameters.
        super().__init__()

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Transform ``embeddings`` according to ``condensed_edge_types``.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
class TranslationOperator(RelationwiseOperatorBase):
    """
    Translation-based relation operator (as in the cited NeurIPS 2013 paper).

    Every edge type owns a learnable vector of length ``node_emb_dim``.
    The operator shifts each node embedding by the vector of the edge type
    it is paired with, i.e. ``embeddings + r[edge_type]``. The relation
    vectors are trained together with the rest of the model.

    See https://papers.nips.cc/paper_files/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__(num_edge_types=num_edge_types, node_emb_dim=node_emb_dim)
        # Lookup table: one translation vector per edge type.
        self.edge_type_embeddings = nn.Embedding(
            num_embeddings=num_edge_types,
            embedding_dim=node_emb_dim,
        )

    def forward(self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor):
        """
        Add the learned vector of each edge type to the matching embedding.

        Args:
            embeddings: Node embeddings; their last dimension is assumed to
                be ``node_emb_dim``.
            condensed_edge_types: Integer edge-type ids used to look up the
                translation vectors.

        Returns:
            ``embeddings`` plus the looked-up translation vectors.
        """
        translation = self.edge_type_embeddings(condensed_edge_types)
        return embeddings + translation
class DiagonalOperator(RelationwiseOperatorBase):
    """
    Diagonal (elementwise-scaling) relation operator.

    Every edge type owns a learnable vector of length ``node_emb_dim``.
    Multiplying a node embedding elementwise by that vector is equivalent to
    applying a diagonal matrix whose diagonal is the relation vector.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__(num_edge_types=num_edge_types, node_emb_dim=node_emb_dim)
        # Lookup table: the diagonal of one relation matrix per edge type.
        self.edge_type_embeddings = nn.Embedding(
            num_embeddings=num_edge_types,
            embedding_dim=node_emb_dim,
        )

    def forward(self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor):
        """
        Scale each embedding elementwise by the vector of its edge type.

        Args:
            embeddings: Node embeddings; their last dimension is assumed to
                be ``node_emb_dim``.
            condensed_edge_types: Integer edge-type ids used to look up the
                diagonal vectors.

        Returns:
            The elementwise product of ``embeddings`` and the looked-up
            vectors.
        """
        diagonal = self.edge_type_embeddings(condensed_edge_types)
        return embeddings * diagonal
class ComplexDiagonalOperator(RelationwiseOperatorBase):
    """
    A complex diagonal operator for heterogeneous graph embeddings.

    Embeddings are interpreted as complex vectors stored as
    ``[real half | imaginary half]`` along the last dimension. Each edge type
    owns a learnable complex vector stored the same way. The operator
    multiplies node embeddings elementwise (in the complex sense) by the
    relation vector of their edge type.

    See https://proceedings.mlr.press/v48/trouillon16.pdf.

    Args:
        num_edge_types: Number of distinct (condensed) edge types.
        node_emb_dim: Total embedding size (real + imaginary); must be even.

    Raises:
        ValueError: if ``node_emb_dim`` is odd.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__(num_edge_types=num_edge_types, node_emb_dim=node_emb_dim)
        if node_emb_dim % 2 != 0:
            raise ValueError("Complex embeddings require an even embedding dimension.")

        # One complex relation vector per edge type, packed as [real | imag].
        self.edge_type_embeddings = nn.Embedding(
            num_embeddings=num_edge_types,
            embedding_dim=node_emb_dim,
        )

    def real_part(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the first half of the last dimension (real components)."""
        # Slicing the last dimension (rather than dim 1) keeps this correct
        # for inputs with any number of leading dimensions.
        return embeddings[..., : embeddings.shape[-1] // 2]

    def imag_part(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the second half of the last dimension (imaginary components)."""
        return embeddings[..., embeddings.shape[-1] // 2 :]

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Complex-multiply each embedding by its edge type's relation vector.

        Args:
            embeddings: Node embeddings packed as ``[real | imag]`` along the
                last dimension (assumed size ``node_emb_dim``).
            condensed_edge_types: Integer edge-type ids; their shape must
                broadcast against ``embeddings`` minus its last dimension.

        Returns:
            The complex product packed as ``[real | imag]``, same shape as
            the broadcast of ``embeddings`` and the relation vectors.
        """
        edge_type_embeddings = self.edge_type_embeddings(condensed_edge_types)

        # Split the embeddings into real and imaginary parts
        src_embeddings_real = self.real_part(embeddings)
        src_embeddings_imag = self.imag_part(embeddings)
        edge_type_embeddings_real = self.real_part(edge_type_embeddings)
        edge_type_embeddings_imag = self.imag_part(edge_type_embeddings)

        # Apply the complex diagonal operator:
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        # Following eq10 here: https://proceedings.mlr.press/v48/trouillon16.pdf
        first = (
            edge_type_embeddings_real * src_embeddings_real
            - edge_type_embeddings_imag * src_embeddings_imag
        )
        second = (
            edge_type_embeddings_real * src_embeddings_imag
            + edge_type_embeddings_imag * src_embeddings_real
        )
        # Re-pack as [real | imag] along the same (last) dimension we split on.
        return torch.cat((first, second), dim=-1)
class LinearOperator(RelationwiseOperatorBase):
    """
    A linear operator for heterogeneous graph embeddings.

    Each edge type owns a learnable ``node_emb_dim x node_emb_dim``
    projection matrix. Every node embedding is right-multiplied
    (``x @ W[edge_type]``) by the matrix of the edge type it is paired with.
    The projection matrices are learned during training and represent the
    different types of relationships between nodes.

    Args:
        num_edge_types: Number of distinct (condensed) edge types.
        node_emb_dim: Dimensionality of the node embeddings.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__(num_edge_types=num_edge_types, node_emb_dim=node_emb_dim)

        # [num_edge_types, node_emb_dim, node_emb_dim]: one matrix per relation.
        self.edge_type_projection = nn.Parameter(
            torch.empty(num_edge_types, node_emb_dim, node_emb_dim),
        )
        # Initialize each relation matrix on its own. Calling xavier_normal_
        # on the whole 3-D tensor treats the trailing dim as a receptive field
        # (fan_in = d*d, fan_out = E*d), producing a std far smaller than
        # Xavier intends for a d x d matrix (fan_in = fan_out = d).
        with torch.no_grad():
            for edge_type in range(num_edge_types):
                nn.init.xavier_normal_(self.edge_type_projection[edge_type])

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Project each embedding with the matrix of its own edge type.

        Args:
            embeddings: ``[batch_size, node_emb_dim]`` node embeddings.
            condensed_edge_types: ``[batch_size]`` integer edge-type ids,
                one per row of ``embeddings`` (assumed, matching how the
                other operators consume them).

        Returns:
            ``[batch_size, node_emb_dim]`` projected embeddings, where row
            ``i`` equals ``embeddings[i] @ W[condensed_edge_types[i]]``.
        """
        # [batch, d] @ [E, d, d] broadcasts to [E, batch, d]: every embedding
        # projected by every relation matrix.
        projected = embeddings @ self.edge_type_projection
        # Keep only the projection belonging to each row's own edge type, so
        # the output lines up with the input rows like the other operators.
        rows = torch.arange(embeddings.shape[0], device=embeddings.device)
        return projected[condensed_edge_types, rows]
class IdentityOperator(RelationwiseOperatorBase):
    """
    Pass-through relation operator.

    Returns the node embeddings unchanged for every edge type. Use it when
    the model should not apply any relation-specific transformation. It
    defines no parameters beyond those of its base class.
    """

    def forward(self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor):
        """
        Return ``embeddings`` unchanged.

        Args:
            embeddings: Node embeddings.
            condensed_edge_types: Ignored; accepted for interface parity
                with the other operators.

        Returns:
            The very same ``embeddings`` tensor object.
        """
        return embeddings