# Source code for gigl.experimental.knowledge_graph_embedding.lib.config.model

from dataclasses import dataclass
from typing import List, Optional

import torchrec

from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
    SamplingConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.model.types import (
    OperatorType,
    SimilarityType,
)


@dataclass
class ModelConfig:
    """Configuration for knowledge graph embedding model architecture.

    Defines the structure and behavior of the embedding model used for link
    prediction in heterogeneous knowledge graphs.

    Attributes:
        node_embedding_dim (int): Dimensionality of node embeddings. Higher
            dimensions can capture more complex relationships but require more
            memory and computation. Defaults to 128.
        embedding_similarity_type (SimilarityType): Type of similarity function
            used to compute scores between node embeddings. Options include
            cosine similarity, dot product, etc. Defaults to
            SimilarityType.COSINE.
        src_operator (OperatorType): Transformation operator applied to source
            node embeddings before computing edge scores. Can be identity (no
            transformation) or learned operators. Defaults to
            OperatorType.IDENTITY.
        dst_operator (OperatorType): Transformation operator applied to
            destination node embeddings before computing edge scores. Can be
            identity (no transformation) or learned operators. Defaults to
            OperatorType.IDENTITY.
        training_sampling (Optional[SamplingConfig]): Sampling configuration
            used during training phase. Populated at runtime from training
            config. Defaults to None.
        validation_sampling (Optional[SamplingConfig]): Sampling configuration
            used during validation phase. Populated at runtime from validation
            config. Defaults to None.
        testing_sampling (Optional[SamplingConfig]): Sampling configuration
            used during testing phase. Populated at runtime from testing
            config. Defaults to None.
        num_edge_types (Optional[int]): Number of distinct edge types in the
            knowledge graph. Populated at runtime from graph metadata.
            Defaults to None.
        embeddings_config (Optional[List[torchrec.EmbeddingBagConfig]]):
            TorchRec embedding configuration for sparse embeddings. Specifies
            embedding tables, sharding strategies, and optimization settings.
            Populated at runtime. Defaults to None.
    """

    # User-specified architecture settings.
    node_embedding_dim: int = 128
    embedding_similarity_type: SimilarityType = SimilarityType.COSINE
    src_operator: OperatorType = OperatorType.IDENTITY
    dst_operator: OperatorType = OperatorType.IDENTITY

    # Below fields are populated at runtime.
    training_sampling: Optional[SamplingConfig] = None
    validation_sampling: Optional[SamplingConfig] = None
    testing_sampling: Optional[SamplingConfig] = None
    num_edge_types: Optional[int] = None
    embeddings_config: Optional[List[torchrec.EmbeddingBagConfig]] = None