# Source code for gigl.src.training.trainer

import argparse
from typing import Optional

import gigl.src.common.constants.gcs as gcs_constants
from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.file_loader import FileLoader
from gigl.src.common.utils.metrics_service_provider import initialize_metrics

# TODO: (svij) Rename Trainer to TrainerV1
from gigl.src.training.v1.trainer import Trainer as TrainerV1
from gigl.src.training.v2.glt_trainer import GLTTrainer

# Module-level logger shared by the Trainer class and the CLI entry point below.
logger = Logger()
class Trainer:
    """Top-level training component.

    Dispatches a training run to either the v1 trainer or the v2 GLT-backed
    trainer, based on the loaded GbmlConfig, after first cleaning up any
    pre-existing trainer output paths so data is not clobbered.
    """

    def __remove_existing_trainer_paths(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        applied_task_identifier: AppliedTaskIdentifier,
    ) -> None:
        """
        Clean up paths that Trainer would be writing to in order to avoid
        clobbering of data. These paths are inferred from the GbmlConfig and
        the AppliedTaskIdentifier.

        Args:
            gbml_config_pb_wrapper: Wrapper around the loaded GbmlConfig proto;
                supplies the trained-model output paths to delete.
            applied_task_identifier: Unique job identifier; used to derive the
                trainer asset directory on GCS.
        """
        logger.info("Preparing staging paths for Trainer...")
        # Delete both the trainer's GCS asset directory and every output path
        # recorded in the trained-model metadata.
        paths_to_delete = [
            gcs_constants.get_trainer_asset_dir_gcs_path(
                applied_task_identifier=applied_task_identifier
            )
        ] + gbml_config_pb_wrapper.trained_model_metadata_pb_wrapper.get_output_paths()
        file_loader = FileLoader()
        logger.info(f"Will delete files @ the following paths: {paths_to_delete}")
        file_loader.delete_files(uris=paths_to_delete)

    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        cpu_docker_uri: Optional[str] = None,
        cuda_docker_uri: Optional[str] = None,
    ) -> None:
        """Execute training and/or evaluation for the given task.

        Args:
            applied_task_identifier: Unique job identifier.
            task_config_uri: URI of the GbmlConfig for this task.
            resource_config_uri: URI of the resource/env specification.
            cpu_docker_uri: Optional docker image URI for CPU training.
            cuda_docker_uri: Optional docker image URI for GPU training.
        """
        gbml_config_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
            gbml_config_uri=task_config_uri
        )
        # Nothing to do if both training and evaluation are disabled.
        if (
            gbml_config_wrapper.shared_config.should_skip_training
            and gbml_config_wrapper.shared_config.should_skip_model_evaluation
        ):
            logger.info("Skipping both training and evaluation. Exiting.")
            return
        # Only clear existing outputs when we are actually going to retrain;
        # an evaluation-only run must keep the previously trained model.
        if not gbml_config_wrapper.shared_config.should_skip_training:
            self.__remove_existing_trainer_paths(
                gbml_config_pb_wrapper=gbml_config_wrapper,
                applied_task_identifier=applied_task_identifier,
            )
        # Dispatch to the backend selected by the config.
        if gbml_config_wrapper.should_use_glt_backend:
            trainer_v2 = GLTTrainer()
            trainer_v2.run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                resource_config_uri=resource_config_uri,
                cpu_docker_uri=cpu_docker_uri,
                cuda_docker_uri=cuda_docker_uri,
            )
        else:
            trainer_v1 = TrainerV1()
            trainer_v1.run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                resource_config_uri=resource_config_uri,
                cpu_docker_uri=cpu_docker_uri,
                cuda_docker_uri=cuda_docker_uri,
            )
if __name__ == "__main__":
    # CLI entry point: parse arguments, initialize metrics, and run the Trainer.
    parser = argparse.ArgumentParser(
        description="Program to generate embeddings from a GBML model"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="Gbml config uri",
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="Runtime argument for resource and env specifications of each component",
    )
    parser.add_argument(
        "--cpu_docker_uri",
        type=str,
        help="User Specified or KFP compiled Docker Image for CPU training",
        required=False,
    )
    parser.add_argument(
        "--cuda_docker_uri",
        type=str,
        help="User Specified or KFP compiled Docker Image for GPU training",
        required=False,
    )
    args = parser.parse_args()

    # The first three arguments are effectively required; fail fast if any is
    # missing or empty.
    if not args.job_name or not args.task_config_uri or not args.resource_config_uri:
        raise RuntimeError("Missing command-line arguments")

    applied_task_identifier = AppliedTaskIdentifier(args.job_name)
    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
    cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri

    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)

    trainer = Trainer()
    trainer.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        resource_config_uri=resource_config_uri,
        cpu_docker_uri=cpu_docker_uri,
        cuda_docker_uri=cuda_docker_uri,
    )