gigl.src.training.v1.lib.training_process#

Attributes#

logger

parser

Classes#

GnnTrainingProcess

Functions#
generate_trainer_instance(gbml_config_pb_wrapper)

get_torch_profiler_instance(gbml_config_pb_wrapper)

save_model(trainer, gbml_config_pb_wrapper)

setup_model_device(model, ...)

Configures the model by moving it to the given device, syncing batch norm layers, and wrapping it in DDP with the relevant flags, such as find_unused_parameters.

Module Contents#

class gigl.src.training.v1.lib.training_process.GnnTrainingProcess[source]#
run(task_config_uri, device)[source]#
Parameters:

  • task_config_uri

  • device
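
A minimal usage sketch of the entry point is shown below. The task config URI value, the zero-argument constructor, and the device choice are illustrative assumptions, not values confirmed by this module:

```python
import torch

from gigl.src.training.v1.lib.training_process import GnnTrainingProcess

# Hypothetical task config URI; substitute whatever URI your GiGL deployment uses.
task_config_uri = "gs://my-bucket/gbml_config.yaml"

process = GnnTrainingProcess()
process.run(
    task_config_uri=task_config_uri,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
```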
gigl.src.training.v1.lib.training_process.generate_trainer_instance(gbml_config_pb_wrapper)[source]#
Parameters:

gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)

Return type:

gigl.src.training.v1.lib.base_trainer.BaseTrainer
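
A hedged sketch of how this factory might be used; the build_trainer wrapper is hypothetical, and constructing the GbmlConfigPbWrapper is assumed to happen elsewhere:

```python
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.training.v1.lib.base_trainer import BaseTrainer
from gigl.src.training.v1.lib.training_process import generate_trainer_instance


def build_trainer(gbml_config_pb_wrapper: GbmlConfigPbWrapper) -> BaseTrainer:
    # The concrete trainer class is resolved from the task config;
    # callers only depend on the BaseTrainer interface.
    return generate_trainer_instance(gbml_config_pb_wrapper)
```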

gigl.src.training.v1.lib.training_process.get_torch_profiler_instance(gbml_config_pb_wrapper)[source]#
Parameters:

gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)

Return type:

Optional[gigl.src.common.modeling_task_specs.utils.profiler_wrapper.TorchProfiler]
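
Because the return type is Optional, callers should check for None before using the profiler. A minimal sketch (the maybe_profile helper is hypothetical):

```python
from typing import Optional

from gigl.src.common.modeling_task_specs.utils.profiler_wrapper import TorchProfiler
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.training.v1.lib.training_process import get_torch_profiler_instance


def maybe_profile(gbml_config_pb_wrapper: GbmlConfigPbWrapper) -> Optional[TorchProfiler]:
    # None means profiling was not enabled in the task config.
    profiler = get_torch_profiler_instance(gbml_config_pb_wrapper)
    if profiler is None:
        return None
    # Hand the profiler to the training loop; the integration point is
    # trainer-specific and not part of this module's public API.
    return profiler
```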

gigl.src.training.v1.lib.training_process.save_model(trainer, gbml_config_pb_wrapper)[source]#
Parameters:

  • trainer (gigl.src.training.v1.lib.base_trainer.BaseTrainer)

  • gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
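
A hedged end-to-end sketch: build a trainer, train, then persist the model. The train_and_save wrapper is hypothetical, and the training step is elided because the BaseTrainer API is not part of this module:

```python
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.training.v1.lib.training_process import (
    generate_trainer_instance,
    save_model,
)


def train_and_save(gbml_config_pb_wrapper: GbmlConfigPbWrapper) -> None:
    trainer = generate_trainer_instance(gbml_config_pb_wrapper)
    # ... run the trainer's training loop here (BaseTrainer API not shown) ...
    save_model(trainer, gbml_config_pb_wrapper)
```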
gigl.src.training.v1.lib.training_process.setup_model_device(model, supports_distributed_training, should_enable_find_unused_parameters, device)[source]#

Configures the model by moving it to the given device, syncing batch norm layers, and wrapping it in DDP with the relevant flags, such as find_unused_parameters.

Parameters:

  • model (torch.nn.Module) – Model initialized for training

  • supports_distributed_training (bool) – Whether distributed training is supported, as defined in the modeling task spec

  • should_enable_find_unused_parameters (bool) – Whether to allow parameters to not receive gradients on the backward pass in DDP

  • device (torch.device) – Torch device to set the model to
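
The description above maps onto standard PyTorch primitives. Here is a hedged sketch of equivalent behavior, not the actual implementation; it assumes DDP wrapping only applies when a process group is initialized:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_model_device_sketch(
    model: nn.Module,
    supports_distributed_training: bool,
    should_enable_find_unused_parameters: bool,
    device: torch.device,
) -> nn.Module:
    # Move the model's parameters and buffers to the target device.
    model = model.to(device)
    if supports_distributed_training and dist.is_available() and dist.is_initialized():
        # Sync batch-norm statistics across ranks instead of keeping
        # per-process running statistics.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # Wrap in DDP; find_unused_parameters tolerates parameters that
        # receive no gradient on a given backward pass (at some overhead).
        model = DDP(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=should_enable_find_unused_parameters,
        )
    return model
```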

gigl.src.training.v1.lib.training_process.logger[source]#
gigl.src.training.v1.lib.training_process.parser[source]#