# Source code for gigl.experimental.knowledge_graph_embedding.lib.config.evaluation
from dataclasses import dataclass, field
from typing import List, Optional
from gigl.experimental.knowledge_graph_embedding.lib.config.dataloader import (
    DataloaderConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
    SamplingConfig,
)
@dataclass
class EvaluationPhaseConfig:
    """
    Configuration for evaluation phases (validation/testing) during knowledge graph embedding training.

    Controls how model performance is measured during training (validation phase) and after
    training completion (testing phase). Uses ranking-based metrics to assess link prediction quality.

    Attributes:
        dataloader (DataloaderConfig): Configuration for data loading during evaluation
            (workers, memory pinning). Defaults to DataloaderConfig().
        step_frequency (Optional[int]): How often to run evaluation during training (every N steps).
            If None, evaluation runs only at the end of training. Defaults to None.
        num_batches (Optional[int]): Maximum number of batches to evaluate. Useful for faster
            evaluation on large datasets by sampling a subset. If None, evaluates all data.
            Defaults to None.
        hit_rates_at_k (List[int]): List of k values for computing Hit@k (Hits at k) metrics.
            Hit@k measures if the correct answer appears in the top k predictions.
            Must not be empty. Defaults to [1, 10, 100].
        sampling (SamplingConfig): Negative sampling configuration for evaluation. Should match
            or be compatible with training sampling to ensure fair comparison.
            Defaults to SamplingConfig().

    Raises:
        ValueError: On construction, if ``hit_rates_at_k`` is empty or if its largest k exceeds
            the number of negatives per edge configured in ``sampling``.
    """

    dataloader: DataloaderConfig = field(default_factory=DataloaderConfig)
    step_frequency: Optional[int] = None
    num_batches: Optional[int] = None
    # A list default must go through a factory so instances never share one list object.
    hit_rates_at_k: List[int] = field(default_factory=lambda: [1, 10, 100])
    sampling: SamplingConfig = field(default_factory=SamplingConfig)

    def __post_init__(self) -> None:
        """
        Validate the evaluation configuration after dataclass initialization.

        Checks that the total number of negative samples per edge (random + in-batch) is
        at least as large as the biggest requested k, so that every Hit@k in
        ``hit_rates_at_k`` can be computed.

        Raises:
            ValueError: If ``hit_rates_at_k`` is empty, or if max(hit_rates_at_k) exceeds
                num_random_negatives_per_edge + num_inbatch_negatives_per_edge.
        """
        # An empty list would make max() fail with an opaque "arg is an empty sequence"
        # ValueError; raise the same exception type with a message that names the field.
        if not self.hit_rates_at_k:
            raise ValueError(
                "`hit_rates_at_k` must contain at least one k value, got an empty list."
            )

        max_k = max(self.hit_rates_at_k)
        num_random = self.sampling.num_random_negatives_per_edge
        num_inbatch = self.sampling.num_inbatch_negatives_per_edge

        if max_k > num_random + num_inbatch:
            # Built from adjacent literals so the message carries no source indentation.
            raise ValueError(
                "Validation `num_random_negatives_per_edge` + `num_inbatch_negatives_per_edge` "
                "must be >= max(hit_rates_at_k). "
                f"Got ({num_random} + {num_inbatch}) and {max_k}"
            )