# Source code for gigl.experimental.knowledge_graph_embedding.lib.config.evaluation
from dataclasses import dataclass, field
from typing import List, Optional
from gigl.experimental.knowledge_graph_embedding.lib.config.dataloader import (
DataloaderConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
SamplingConfig,
)
@dataclass
class EvaluationPhaseConfig:
    """
    Configuration for evaluation phases (validation/testing) during knowledge graph embedding training.

    Controls how model performance is measured during training (validation phase) and after
    training completion (testing phase). Uses ranking-based metrics to assess link prediction quality.

    Attributes:
        dataloader (DataloaderConfig): Configuration for data loading during evaluation (workers, memory pinning).
            Defaults to DataloaderConfig() with standard settings.
        step_frequency (Optional[int]): How often to run evaluation during training (every N steps).
            If None, evaluation runs only at the end of training. Defaults to None.
        num_batches (Optional[int]): Maximum number of batches to evaluate. Useful for faster evaluation
            on large datasets by sampling a subset. If None, evaluates all data. Defaults to None.
        hit_rates_at_k (List[int]): List of k values for computing Hit@k (Hits at k) metrics.
            Hit@k measures if the correct answer appears in the top k predictions.
            Common values are [1, 10, 100]. Defaults to [1, 10, 100].
        sampling (SamplingConfig): Negative sampling configuration for evaluation. Should match or be
            compatible with training sampling to ensure fair comparison.
            Defaults to SamplingConfig() with standard settings.
    """

    dataloader: DataloaderConfig = field(default_factory=DataloaderConfig)
    step_frequency: Optional[int] = None
    num_batches: Optional[int] = None
    # Mutable default: default_factory avoids the shared-list-across-instances pitfall.
    hit_rates_at_k: List[int] = field(default_factory=lambda: [1, 10, 100])
    sampling: SamplingConfig = field(default_factory=SamplingConfig)

    def __post_init__(self) -> None:
        """
        Post-initialization validation of evaluation configuration parameters.

        Validates that the total number of negative samples (random + in-batch) is
        sufficient to compute the requested Hit@k metrics. This ensures that evaluation
        can meaningfully compute metrics for all requested k values.

        Raises:
            ValueError: If max(hit_rates_at_k) exceeds the total number of negative samples
                available for ranking (num_random_negatives_per_edge + num_inbatch_negatives_per_edge).
        """
        # Ranking against fewer negatives than the largest requested k would make
        # Hit@k trivially 1.0, so reject such configurations up front.
        if max(self.hit_rates_at_k) > (
            self.sampling.num_random_negatives_per_edge
            + self.sampling.num_inbatch_negatives_per_edge
        ):
            raise ValueError(
                f"""Validation `num_random_negatives_per_edge` + `num_inbatch_negatives_per_edge` must be >= max(hit_rates_at_k).
                Got ({self.sampling.num_random_negatives_per_edge} + {self.sampling.num_inbatch_negatives_per_edge}) and {max(self.hit_rates_at_k)}"""
            )