gigl.experimental.knowledge_graph_embedding.lib.config.training#

Classes#

CheckpointingConfig

Configuration for model checkpointing during training.

DistributedConfig

Configuration for distributed training across multiple GPUs or processes.

EarlyStoppingConfig

Configuration for early stopping based on validation performance.

LoggingConfig

Configuration for training progress logging.

OptimizerConfig

Configuration for separate optimizers for sparse and dense parameters.

OptimizerParamsConfig

Configuration for optimizer hyperparameters.

TrainConfig

Main training configuration that orchestrates all training-related settings.

Module Contents#

class gigl.experimental.knowledge_graph_embedding.lib.config.training.CheckpointingConfig[source]#

Configuration for model checkpointing during training.

save_every[source]#

Save a checkpoint every N training steps. Allows recovery from failures and monitoring of training progress. Defaults to 10,000 steps.

Type:

int

should_save_async[source]#

Whether to save checkpoints asynchronously to avoid blocking training. Improves training efficiency but may use additional memory. Defaults to True.

Type:

bool

load_from_path[source]#

Path to a checkpoint file to resume training from. If None, training starts from scratch. Defaults to None.

Type:

Optional[str]

save_to_path[source]#

Directory path where checkpoints will be saved. If None, checkpoints are not saved. Defaults to None.

Type:

Optional[str]

load_from_path: str | None = None[source]#
save_every: int = 10000[source]#
save_to_path: str | None = None[source]#
should_save_async: bool = True[source]#
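
For illustration, a minimal sketch of constructing this config, assuming it behaves like a standard keyword-argument dataclass; the paths shown are placeholders, not real locations:

```python
from gigl.experimental.knowledge_graph_embedding.lib.config.training import (
    CheckpointingConfig,
)

# Save a checkpoint every 5,000 steps, asynchronously, into a placeholder directory.
checkpointing = CheckpointingConfig(
    save_every=5_000,
    should_save_async=True,
    save_to_path="/tmp/kge_checkpoints",  # hypothetical output directory
)

# Resume a previous run by pointing load_from_path at an existing checkpoint.
resumed = CheckpointingConfig(
    load_from_path="/tmp/kge_checkpoints/step_10000",  # hypothetical checkpoint path
    save_to_path="/tmp/kge_checkpoints",
)
```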
class gigl.experimental.knowledge_graph_embedding.lib.config.training.DistributedConfig[source]#

Configuration for distributed training across multiple GPUs or processes.

num_processes_per_machine[source]#

Number of training processes to spawn per machine. Each process typically uses one GPU. Defaults to torch.cuda.device_count() if CUDA is available, otherwise 1.

Type:

int

storage_reservation_percentage[source]#

Storage buffer percentage used by TorchRec to account for overhead in dense tensor and KJT (KeyedJaggedTensor) storage. Defaults to 0.1 (10%).

Type:

float

num_processes_per_machine: int = 1[source]#
storage_reservation_percentage: float = 0.1[source]#
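
A minimal sketch of configuring one process per visible GPU, assuming keyword-argument construction; torch is only used here to query the device count, and the buffer value is illustrative:

```python
import torch

from gigl.experimental.knowledge_graph_embedding.lib.config.training import (
    DistributedConfig,
)

# One training process per visible GPU, falling back to a single process on CPU-only hosts.
distributed = DistributedConfig(
    num_processes_per_machine=(
        torch.cuda.device_count() if torch.cuda.is_available() else 1
    ),
    storage_reservation_percentage=0.15,  # slightly larger buffer than the 10% default
)
```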
class gigl.experimental.knowledge_graph_embedding.lib.config.training.EarlyStoppingConfig[source]#

Configuration for early stopping based on validation performance.

patience[source]#

Number of evaluation steps to wait for improvement before stopping training. Helps prevent overfitting by stopping when validation performance plateaus. If None, early stopping is disabled. Defaults to None.

Type:

Optional[int]

patience: int | None = None[source]#
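
A short sketch of enabling and disabling early stopping, assuming keyword-argument construction; the patience value is arbitrary:

```python
from gigl.experimental.knowledge_graph_embedding.lib.config.training import (
    EarlyStoppingConfig,
)

# Stop once validation metrics fail to improve for 5 consecutive evaluation steps.
early_stopping = EarlyStoppingConfig(patience=5)

# The default (patience=None) leaves early stopping disabled.
no_early_stopping = EarlyStoppingConfig()
```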
class gigl.experimental.knowledge_graph_embedding.lib.config.training.LoggingConfig[source]#

Configuration for training progress logging.

log_every[source]#

Log training metrics every N steps. More frequent logging provides better monitoring but may slow down training. Defaults to 1 (log every step).

Type:

int

log_every: int = 1[source]#
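
As a small sketch, assuming keyword-argument construction, logging every 100 steps to reduce overhead on long runs:

```python
from gigl.experimental.knowledge_graph_embedding.lib.config.training import LoggingConfig

# Log training metrics every 100 steps instead of every step.
logging_config = LoggingConfig(log_every=100)
```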
class gigl.experimental.knowledge_graph_embedding.lib.config.training.OptimizerConfig[source]#

Configuration for separate optimizers for sparse and dense parameters.

Knowledge graph embedding models typically have both sparse embeddings (updated only for nodes/edges in each batch) and dense parameters (updated every batch). Different learning rates are often beneficial for these parameter types.

sparse[source]#

Optimizer parameters for the sparse node embeddings. Defaults to OptimizerParamsConfig(lr=0.01, weight_decay=0.001).

Type:

OptimizerParamsConfig

dense[source]#

Optimizer parameters for dense model parameters (linear layers, etc.). Defaults to OptimizerParamsConfig(lr=0.01, weight_decay=0.001).

Type:

OptimizerParamsConfig

dense: OptimizerParamsConfig[source]#
sparse: OptimizerParamsConfig[source]#
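
To make the sparse/dense split concrete, a hedged sketch assuming keyword-argument construction, with different learning rates for the two parameter types; the values are illustrative, not recommendations:

```python
from gigl.experimental.knowledge_graph_embedding.lib.config.training import (
    OptimizerConfig,
    OptimizerParamsConfig,
)

# Sparse embedding rows are updated only when they appear in a batch, while dense
# parameters are updated every step, so they often benefit from different learning rates.
optimizer = OptimizerConfig(
    sparse=OptimizerParamsConfig(lr=0.05, weight_decay=0.0),
    dense=OptimizerParamsConfig(lr=0.001, weight_decay=0.001),
)
```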
class gigl.experimental.knowledge_graph_embedding.lib.config.training.OptimizerParamsConfig[source]#

Configuration for optimizer hyperparameters.

lr[source]#

Learning rate for the optimizer. Controls the step size during gradient descent. Higher values lead to faster convergence but may overshoot the minimum. Defaults to 0.001.

Type:

float

weight_decay[source]#

L2 regularization coefficient applied to model parameters. Helps prevent overfitting by penalizing large weights. Defaults to 0.001.

Type:

float

lr: float = 0.001[source]#
weight_decay: float = 0.001[source]#
class gigl.experimental.knowledge_graph_embedding.lib.config.training.TrainConfig[source]#

Main training configuration that orchestrates all training-related settings.

This configuration combines optimization, data loading, distributed training, checkpointing, and monitoring settings for knowledge graph embedding training.

max_steps[source]#

Maximum number of training steps to perform. If None, training continues until early stopping or manual interruption. Defaults to None.

Type:

Optional[int]

early_stopping[source]#

Configuration for early stopping based on validation metrics. Defaults to EarlyStoppingConfig() with patience=None, i.e. early stopping disabled.

Type:

EarlyStoppingConfig

dataloader[source]#

Configuration for data loading (number of workers, memory pinning). Defaults to DataloaderConfig() with standard settings.

Type:

DataloaderConfig

sampling[source]#

Configuration for negative sampling strategy during training. Defaults to SamplingConfig() with standard settings.

Type:

SamplingConfig

optimizer[source]#

Configuration for separate sparse and dense optimizers. Defaults to OptimizerConfig() with standard settings.

Type:

OptimizerConfig

distributed[source]#

Configuration for multi-GPU/multi-process training. Defaults to DistributedConfig() with auto-detected GPU count.

Type:

DistributedConfig

checkpointing[source]#

Configuration for saving and loading model checkpoints. Defaults to CheckpointingConfig() with standard settings.

Type:

CheckpointingConfig

logging[source]#

Configuration for training progress logging frequency. Defaults to LoggingConfig() with log-every-step setting.

Type:

LoggingConfig

checkpointing: CheckpointingConfig[source]#
dataloader: gigl.experimental.knowledge_graph_embedding.lib.config.dataloader.DataloaderConfig[source]#
distributed: DistributedConfig[source]#
early_stopping: EarlyStoppingConfig[source]#
logging: LoggingConfig[source]#
max_steps: int | None = None[source]#
optimizer: OptimizerConfig[source]#
sampling: gigl.experimental.knowledge_graph_embedding.lib.config.sampling.SamplingConfig[source]#
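
Putting the pieces together, an end-to-end sketch assuming these configs are dataclass-style and accept keyword arguments, with DataloaderConfig and SamplingConfig imported from the sibling modules named in the attribute types above; all specific values are illustrative only:

```python
from gigl.experimental.knowledge_graph_embedding.lib.config.dataloader import DataloaderConfig
from gigl.experimental.knowledge_graph_embedding.lib.config.sampling import SamplingConfig
from gigl.experimental.knowledge_graph_embedding.lib.config.training import (
    CheckpointingConfig,
    DistributedConfig,
    EarlyStoppingConfig,
    LoggingConfig,
    OptimizerConfig,
    OptimizerParamsConfig,
    TrainConfig,
)

# Train for at most 100k steps, stopping early if validation stalls for 10 evaluations.
train_config = TrainConfig(
    max_steps=100_000,
    early_stopping=EarlyStoppingConfig(patience=10),
    dataloader=DataloaderConfig(),  # default data-loading settings
    sampling=SamplingConfig(),      # default negative-sampling settings
    optimizer=OptimizerConfig(
        sparse=OptimizerParamsConfig(lr=0.01, weight_decay=0.001),
        dense=OptimizerParamsConfig(lr=0.001, weight_decay=0.001),
    ),
    distributed=DistributedConfig(num_processes_per_machine=1),
    checkpointing=CheckpointingConfig(
        save_every=10_000,
        save_to_path="/tmp/kge_checkpoints",  # hypothetical output directory
    ),
    logging=LoggingConfig(log_every=100),
)
```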