# Source code for gigl.experimental.knowledge_graph_embedding.lib.model.operators
import torch
import torch.nn as nn
class RelationwiseOperatorBase(nn.Module):
    """
    Base class for relationwise operators in heterogeneous graph embeddings.

    Each operator applies a transformation to the node embeddings based on
    the context of a specific relation / edge-type. Subclasses are expected
    to override :meth:`forward`.

    Args:
        num_edge_types: Number of distinct edge types. The base class does not
            use it; it is accepted so every operator shares one constructor
            signature.
        node_emb_dim: Dimensionality of the node embeddings. Likewise unused
            by the base class itself.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__()

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the relation-specific transformation to ``embeddings``.

        Args:
            embeddings: Node embeddings to transform.
            condensed_edge_types: Edge-type indices selecting the relation
                applied to each embedding.

        Returns:
            The transformed embeddings.

        Raises:
            NotImplementedError: Always; the base class defines no transform.
        """
        raise NotImplementedError
 
class TranslationOperator(RelationwiseOperatorBase):
    """
    A translation operator for heterogeneous graph embeddings.

    This operator adds the edge type embeddings to the node embeddings.
    It is used to model the relationship between nodes in a heterogeneous graph.
    The edge type embeddings are learned during training and are used to
    represent the different types of relationships between nodes.
    See https://papers.nips.cc/paper_files/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf

    Args:
        num_edge_types: Number of rows in the edge-type embedding table.
        node_emb_dim: Width of each edge-type embedding; chosen equal to the
            node embedding width so the two can be added elementwise.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__(num_edge_types=num_edge_types, node_emb_dim=node_emb_dim)
        # One learnable translation vector per edge type.
        self.edge_type_embeddings = nn.Embedding(
            num_embeddings=num_edge_types,
            embedding_dim=node_emb_dim,
        )

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Translate each node embedding by its relation's embedding.

        Args:
            embeddings: Node embeddings; must be broadcastable against the
                looked-up edge-type embeddings.
            condensed_edge_types: Integer indices into the edge-type
                embedding table.

        Returns:
            ``embeddings + edge_type_embeddings[condensed_edge_types]``.
        """
        edge_type_embeddings = self.edge_type_embeddings(condensed_edge_types)
        return embeddings + edge_type_embeddings
 
class DiagonalOperator(RelationwiseOperatorBase):
    """
    A diagonal operator for heterogeneous graph embeddings.

    This operator multiplies the node embeddings elementwise by the edge type
    embeddings, which is equivalent to applying a per-relation diagonal matrix.

    Args:
        num_edge_types: Number of rows in the edge-type embedding table.
        node_emb_dim: Width of each edge-type embedding; chosen equal to the
            node embedding width so the two can be multiplied elementwise.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__(num_edge_types=num_edge_types, node_emb_dim=node_emb_dim)
        # One learnable scaling vector (the diagonal) per edge type.
        self.edge_type_embeddings = nn.Embedding(
            num_embeddings=num_edge_types,
            embedding_dim=node_emb_dim,
        )

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Scale each node embedding by its relation's embedding.

        Args:
            embeddings: Node embeddings; must be broadcastable against the
                looked-up edge-type embeddings.
            condensed_edge_types: Integer indices into the edge-type
                embedding table.

        Returns:
            ``embeddings * edge_type_embeddings[condensed_edge_types]``.
        """
        edge_type_embeddings = self.edge_type_embeddings(condensed_edge_types)
        return embeddings * edge_type_embeddings
 
class ComplexDiagonalOperator(RelationwiseOperatorBase):
    """
    A complex diagonal operator for heterogeneous graph embeddings.

    This operator splits the node embeddings into real and imaginary parts
    (first and second half of the last dimension, respectively) and multiplies
    them, as complex numbers, by the edge type embeddings, which are split the
    same way.
    See https://proceedings.mlr.press/v48/trouillon16.pdf.

    Args:
        num_edge_types: Number of rows in the edge-type embedding table.
        node_emb_dim: Width of the embeddings; must be even so it can be split
            into equally sized real and imaginary halves.

    Raises:
        ValueError: If ``node_emb_dim`` is odd.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__(num_edge_types=num_edge_types, node_emb_dim=node_emb_dim)
        if node_emb_dim % 2 != 0:
            raise ValueError("Complex embeddings require an even embedding dimension.")
        self.edge_type_embeddings = nn.Embedding(
            num_embeddings=num_edge_types,
            embedding_dim=node_emb_dim,
        )

    def real_part(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the first half of the last dimension (the real components)."""
        # Slice along the last dim so inputs with extra leading dims also work;
        # for 2-D input this is identical to slicing dim 1.
        return embeddings[..., : embeddings.shape[-1] // 2]

    def imag_part(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the second half of the last dimension (the imaginary components)."""
        return embeddings[..., embeddings.shape[-1] // 2 :]

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Multiply node embeddings by relation embeddings as complex vectors.

        Args:
            embeddings: Node embeddings whose last dimension holds the real
                half followed by the imaginary half.
            condensed_edge_types: Integer indices into the edge-type
                embedding table.

        Returns:
            A tensor with the real half of the product followed by the
            imaginary half, concatenated along the last dimension.
        """
        edge_type_embeddings = self.edge_type_embeddings(condensed_edge_types)

        # Split the embeddings into real and imaginary parts
        src_embeddings_real = self.real_part(embeddings)
        src_embeddings_imag = self.imag_part(embeddings)
        edge_type_embeddings_real = self.real_part(edge_type_embeddings)
        edge_type_embeddings_imag = self.imag_part(edge_type_embeddings)

        # Apply the complex diagonal operator
        # Following eq10 here: https://proceedings.mlr.press/v48/trouillon16.pdf
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        first = (
            edge_type_embeddings_real * src_embeddings_real
            - edge_type_embeddings_imag * src_embeddings_imag
        )
        second = (
            edge_type_embeddings_real * src_embeddings_imag
            + edge_type_embeddings_imag * src_embeddings_real
        )
        return torch.cat((first, second), dim=-1)
 
class LinearOperator(RelationwiseOperatorBase):
    """
    A linear operator for heterogeneous graph embeddings.

    This operator projects the node embeddings using a learned projection matrix
    for each edge type. The projection matrix is learned during training and
    is used to represent the different types of relationships between nodes.

    Args:
        num_edge_types: Number of projection matrices (one per edge type).
        node_emb_dim: Each projection matrix is ``node_emb_dim x node_emb_dim``.
    """

    def __init__(self, num_edge_types: int, node_emb_dim: int):
        super().__init__(num_edge_types=num_edge_types, node_emb_dim=node_emb_dim)
        # Stack of per-edge-type square projection matrices:
        # [num_edge_types, node_emb_dim, node_emb_dim].
        self.edge_type_projection = nn.Parameter(
            torch.empty(num_edge_types, node_emb_dim, node_emb_dim),
        )
        nn.init.xavier_normal_(self.edge_type_projection)

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Project the node embeddings through every edge type's matrix.

        Args:
            embeddings: Node embeddings; matmul broadcasts them against the
                stacked projection matrices.
            condensed_edge_types: Accepted for interface compatibility but not
                read by this method.

        Returns:
            The projections under all edge types, i.e. (for 2-D ``embeddings``)
            a tensor of shape ``[num_edge_types, batch_size, node_emb_dim]``.
        """
        # NOTE(review): unlike the other operators in this module, this does
        # not select a projection per row via `condensed_edge_types`; it
        # returns the result for every edge type. Confirm callers expect the
        # extra leading dimension before changing it.
        return (
            embeddings @ self.edge_type_projection
        )  # [num_edge_types, batch_size, node_emb_dim]
 
class IdentityOperator(RelationwiseOperatorBase):
    """
    An identity operator for heterogeneous graph embeddings.

    This operator does not apply any transformation to the node embeddings.
    It is used when no relation operator is needed.
    """

    def forward(
        self, embeddings: torch.Tensor, condensed_edge_types: torch.Tensor
    ) -> torch.Tensor:
        """
        Return ``embeddings`` unchanged.

        Args:
            embeddings: Node embeddings.
            condensed_edge_types: Accepted for interface compatibility; ignored.

        Returns:
            The very same ``embeddings`` tensor object (no copy is made).
        """
        return embeddings