gigl.src.training.v1.lib.training_process
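Attributes
Classes
- GnnTrainingProcess
Functions
- generate_trainer_instance(gbml_config_pb_wrapper)
- get_torch_profiler_instance(gbml_config_pb_wrapper)
- save_model(trainer, gbml_config_pb_wrapper)
- setup_model_device(model, supports_distributed_training, should_enable_find_unused_parameters, device): Configures the model by moving it to the device, syncing batch norm, and wrapping it with DDP using the relevant flags, such as find_unused_parameters.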
Module Contents
- class gigl.src.training.v1.lib.training_process.GnnTrainingProcess
- run(task_config_uri, device)
- Parameters:
task_config_uri (gigl.common.Uri)
device (torch.device)
- gigl.src.training.v1.lib.training_process.generate_trainer_instance(gbml_config_pb_wrapper)
- Parameters:
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
- Return type:
- gigl.src.training.v1.lib.training_process.get_torch_profiler_instance(gbml_config_pb_wrapper)
- Parameters:
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
- Return type:
Optional[gigl.src.common.modeling_task_specs.utils.profiler_wrapper.TorchProfiler]
- gigl.src.training.v1.lib.training_process.save_model(trainer, gbml_config_pb_wrapper)
- Parameters:
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
- gigl.src.training.v1.lib.training_process.setup_model_device(model, supports_distributed_training, should_enable_find_unused_parameters, device)
Configures the model for training: moves it to the given device, syncs batch norm, and wraps it with DDP (DistributedDataParallel) using the relevant flags, such as find_unused_parameters.
- Parameters:
model (torch.nn.Module): Model initialized for training.
supports_distributed_training (bool): Whether distributed training is supported, as defined in the modeling task spec.
should_enable_find_unused_parameters (bool): Whether DDP allows some parameters to receive no gradient on the backward pass.
device (torch.device): Torch device to move the model to.
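The three steps that setup_model_device performs (device placement, batch norm syncing, DDP wrapping) can be sketched with public PyTorch APIs. This is a minimal sketch, not GiGL's implementation: the name setup_model_device_sketch is hypothetical, and gating the batch norm and DDP steps on an initialized process group is an assumption made so the function also works in a single, non-distributed process.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def setup_model_device_sketch(
    model: torch.nn.Module,
    supports_distributed_training: bool,
    should_enable_find_unused_parameters: bool,
    device: torch.device,
) -> torch.nn.Module:
    # Assumption: only sync batch norm and wrap with DDP when the task spec
    # allows distributed training AND a process group is already initialized.
    is_distributed = (
        supports_distributed_training
        and dist.is_available()
        and dist.is_initialized()
    )

    if is_distributed:
        # Replace BatchNorm*d layers with SyncBatchNorm so batch statistics
        # are computed across all processes instead of per process.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Move all parameters and buffers onto the target device.
    model = model.to(device)

    if not is_distributed:
        return model

    # For CUDA modules, pin DDP to that device; for CPU modules,
    # DDP requires device_ids to be None.
    device_ids = [device] if device.type == "cuda" else None
    return DistributedDataParallel(
        model,
        device_ids=device_ids,
        # Lets DDP tolerate parameters that get no gradient in a backward pass.
        find_unused_parameters=should_enable_find_unused_parameters,
    )
```

Called in a plain single process, the sketch just returns the model moved to the device. After torch.distributed.init_process_group has been called, the same call returns a DDP-wrapped model with SyncBatchNorm layers. Enabling find_unused_parameters adds a graph traversal after each forward pass, so it is usually left off unless some parameters are legitimately skipped.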