# Source code for gigl.experimental.knowledge_graph_embedding.lib.config.training
from dataclasses import dataclass, field
from typing import Optional

import torch

from gigl.experimental.knowledge_graph_embedding.lib.config.dataloader import (
    DataloaderConfig,
)
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import (
    SamplingConfig,
)


@dataclass
class OptimizerParamsConfig:
    """
    Configuration for optimizer hyperparameters.

    Attributes:
        lr (float): Learning rate for the optimizer. Controls the step size during gradient
            descent. Higher values lead to faster convergence but may overshoot the minimum.
            Defaults to 0.001.
        weight_decay (float): L2 regularization coefficient applied to model parameters.
            Helps prevent overfitting by penalizing large weights. Defaults to 0.001.
    """

    lr: float = 0.001
    weight_decay: float = (
        0.001  # TODO(nshah): consider supporting weight decay for sparse embeddings.
    )


@dataclass
class OptimizerConfig:
    """
    Configuration for separate optimizers for sparse and dense parameters.

    Knowledge graph embedding models typically have both sparse embeddings (updated only
    for nodes/edges in each batch) and dense parameters (updated every batch). Different
    learning rates are often beneficial for these parameter types.

    Attributes:
        sparse (OptimizerParamsConfig): Optimizer parameters for sparse embeddings (for nodes).
            Defaults to OptimizerParamsConfig(lr=0.01, weight_decay=0.001).
        dense (OptimizerParamsConfig): Optimizer parameters for dense model parameters
            (linear layers, etc.). Defaults to OptimizerParamsConfig(lr=0.01, weight_decay=0.001).
    """

    sparse: OptimizerParamsConfig = field(
        default_factory=lambda: OptimizerParamsConfig(lr=0.01, weight_decay=0.001)
    )
    dense: OptimizerParamsConfig = field(
        default_factory=lambda: OptimizerParamsConfig(lr=0.01, weight_decay=0.001)
    )
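

# Usage sketch (illustrative, not part of the original module): the sparse embedding tables
# and dense parameters can be tuned independently, e.g. a larger learning rate and no weight
# decay for the sparse embeddings. The specific values below are arbitrary examples.
#
#   optimizer_config = OptimizerConfig(
#       sparse=OptimizerParamsConfig(lr=0.05, weight_decay=0.0),
#       dense=OptimizerParamsConfig(lr=0.001, weight_decay=0.001),
#   )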


@dataclass
class DistributedConfig:
    """
    Configuration for distributed training across multiple GPUs or processes.

    Attributes:
        num_processes_per_machine (int): Number of training processes to spawn per machine.
            Each process typically uses one GPU. Defaults to torch.cuda.device_count()
            if CUDA is available, otherwise 1.
        storage_reservation_percentage (float): Storage percentage buffer used by TorchRec
            to account for overhead on dense tensor and KJT storage. Defaults to 0.1 (10%).
    """

    num_processes_per_machine: int = (
        torch.cuda.device_count() if torch.cuda.is_available() else 1
    )
    storage_reservation_percentage: float = 0.1
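

# Usage sketch (illustrative, not part of the original module): the default process count
# follows torch.cuda.device_count() when CUDA is available; for CPU-only debugging a single
# process can be pinned explicitly.
#
#   distributed_config = DistributedConfig(num_processes_per_machine=1)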


@dataclass
class CheckpointingConfig:
    """
    Configuration for model checkpointing during training.

    Attributes:
        save_every (int): Save a checkpoint every N training steps. Allows recovery from
            failures and monitoring of training progress. Defaults to 10,000 steps.
        should_save_async (bool): Whether to save checkpoints asynchronously to avoid blocking
            training. Improves training efficiency but may use additional memory.
            Defaults to True.
        load_from_path (Optional[str]): Path to a checkpoint file to resume training from.
            If None, training starts from scratch. Defaults to None.
        save_to_path (Optional[str]): Directory path where checkpoints will be saved. If None,
            checkpoints are not saved. Defaults to None.
    """

    save_every: int = 10_000
    should_save_async: bool = True
    load_from_path: Optional[str] = None
    save_to_path: Optional[str] = None
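

# Usage sketch (illustrative, not part of the original module): resuming from an existing
# checkpoint while writing new checkpoints to a separate location. The paths below are
# hypothetical placeholders.
#
#   checkpointing_config = CheckpointingConfig(
#       save_every=5_000,
#       load_from_path="gs://my-bucket/checkpoints/step_50000",  # hypothetical path
#       save_to_path="gs://my-bucket/checkpoints/",  # hypothetical path
#   )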


@dataclass
class LoggingConfig:
    """
    Configuration for training progress logging.

    Attributes:
        log_every (int): Log training metrics every N steps. More frequent logging provides
            better monitoring but may slow down training. Defaults to 1 (log every step).
    """

    log_every: int = 1


@dataclass
class EarlyStoppingConfig:
    """
    Configuration for early stopping based on validation performance.

    Attributes:
        patience (Optional[int]): Number of evaluation steps to wait for improvement before
            stopping training. Helps prevent overfitting by stopping when validation
            performance plateaus. If None, early stopping is disabled. Defaults to None.
    """

    patience: Optional[int] = None


@dataclass
class TrainConfig:
    """
    Main training configuration that orchestrates all training-related settings.

    This configuration combines optimization, data loading, distributed training,
    checkpointing, and monitoring settings for knowledge graph embedding training.

    Attributes:
        max_steps (Optional[int]): Maximum number of training steps to perform. If None,
            training continues until early stopping or manual interruption. Defaults to None.
        early_stopping (EarlyStoppingConfig): Configuration for early stopping based on
            validation metrics. Defaults to EarlyStoppingConfig() with no patience limit.
        dataloader (DataloaderConfig): Configuration for data loading (number of workers,
            memory pinning). Defaults to DataloaderConfig() with standard settings.
        sampling (SamplingConfig): Configuration for negative sampling strategy during training.
            Defaults to SamplingConfig() with standard settings.
        optimizer (OptimizerConfig): Configuration for separate sparse and dense optimizers.
            Defaults to OptimizerConfig() with standard settings.
        distributed (DistributedConfig): Configuration for multi-GPU/multi-process training.
            Defaults to DistributedConfig() with auto-detected GPU count.
        checkpointing (CheckpointingConfig): Configuration for saving and loading model
            checkpoints. Defaults to CheckpointingConfig() with standard settings.
        logging (LoggingConfig): Configuration for training progress logging frequency.
            Defaults to LoggingConfig() with log-every-step setting.
    """

    max_steps: Optional[int] = None
    early_stopping: EarlyStoppingConfig = field(default_factory=EarlyStoppingConfig)
    dataloader: DataloaderConfig = field(default_factory=DataloaderConfig)
    sampling: SamplingConfig = field(default_factory=SamplingConfig)
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    distributed: DistributedConfig = field(default_factory=DistributedConfig)
    checkpointing: CheckpointingConfig = field(default_factory=CheckpointingConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
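

# Usage sketch (not part of the original module): constructing a TrainConfig with a few
# illustrative overrides when the module is run directly; every value below is an arbitrary
# example, and all unspecified fields fall back to the defaults defined above.
if __name__ == "__main__":
    example_train_config = TrainConfig(
        max_steps=100_000,  # hypothetical step budget
        early_stopping=EarlyStoppingConfig(patience=5),  # stop after 5 evals without improvement
        optimizer=OptimizerConfig(
            sparse=OptimizerParamsConfig(lr=0.05, weight_decay=0.0),  # dense keeps its default
        ),
        logging=LoggingConfig(log_every=100),  # log every 100 steps instead of every step
    )
    print(example_train_config)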