# Source code for gigl.experimental.knowledge_graph_embedding.lib.config.model
from dataclasses import dataclass
from typing import List, Optional
import torchrec
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
    SamplingConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.model.types import (
    OperatorType,
    SimilarityType,
)
@dataclass
class ModelConfig:
    """
    Configuration for knowledge graph embedding model architecture.

    Defines the structure and behavior of the embedding model used for link prediction
    in heterogeneous knowledge graphs.

    Attributes:
        node_embedding_dim (int): Dimensionality of node embeddings. Higher dimensions can
            capture more complex relationships but require more memory and computation.
            Defaults to 128.
        embedding_similarity_type (SimilarityType): Type of similarity function used to compute scores
            between node embeddings. Options include cosine similarity, dot product, etc.
            Defaults to SimilarityType.COSINE.
        src_operator (OperatorType): Transformation operator applied to source node embeddings before
            computing edge scores. Can be identity (no transformation) or learned operators.
            Defaults to OperatorType.IDENTITY.
        dst_operator (OperatorType): Transformation operator applied to destination node embeddings
            before computing edge scores. Can be identity (no transformation) or learned operators.
            Defaults to OperatorType.IDENTITY.
        training_sampling (Optional[SamplingConfig]): Sampling configuration used during training phase.
            Populated at runtime from training config. Defaults to None.
        validation_sampling (Optional[SamplingConfig]): Sampling configuration used during validation phase.
            Populated at runtime from validation config. Defaults to None.
        testing_sampling (Optional[SamplingConfig]): Sampling configuration used during testing phase.
            Populated at runtime from testing config. Defaults to None.
        num_edge_types (Optional[int]): Number of distinct edge types in the knowledge graph.
            Populated at runtime from graph metadata. Defaults to None.
        embeddings_config (Optional[List[torchrec.EmbeddingBagConfig]]): TorchRec embedding configuration for sparse embeddings.
            Specifies embedding tables, sharding strategies, and optimization settings.
            Populated at runtime. Defaults to None.
    """

    # Architecture settings; each has a usable default.
    node_embedding_dim: int = 128
    embedding_similarity_type: SimilarityType = SimilarityType.COSINE
    src_operator: OperatorType = OperatorType.IDENTITY
    dst_operator: OperatorType = OperatorType.IDENTITY

    # Below fields are populated at runtime; they stay None until then.
    training_sampling: Optional[SamplingConfig] = None
    validation_sampling: Optional[SamplingConfig] = None
    testing_sampling: Optional[SamplingConfig] = None
    num_edge_types: Optional[int] = None
    embeddings_config: Optional[List[torchrec.EmbeddingBagConfig]] = None