import argparse
from typing import Optional
import gigl.src.common.constants.gcs as gcs_constants
from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.file_loader import FileLoader
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
# TODO: (svij) Rename Trainer to TrainerV1
from gigl.src.training.v1.trainer import Trainer as TrainerV1
from gigl.src.training.v2.glt_trainer import GLTTrainer
logger = Logger()


class Trainer:
    """
    Entry point for the training component.

    Loads the GbmlConfig for a task, optionally clears any previous trainer
    outputs, and then delegates the actual work to either the GLT-backed
    trainer (``GLTTrainer``) or the V1 trainer (``TrainerV1``).
    """

    def __remove_existing_trainer_paths(
        self,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        applied_task_identifier: AppliedTaskIdentifier,
    ) -> None:
        """
        Clean up paths that Trainer would be writing to in order to avoid clobbering of data.
        These paths are inferred from the GbmlConfig and the AppliedTaskIdentifier.

        :param gbml_config_pb_wrapper: Wrapper around the task's GbmlConfig; its trained
            model metadata supplies the output paths to delete.
        :param applied_task_identifier: Identifier of the job; used to resolve the
            trainer asset directory to delete.
        :return: None
        """
        logger.info("Preparing staging paths for Trainer...")
        # The trainer asset directory for this job, plus every output path
        # declared by the trained model metadata.
        paths_to_delete = (
            [
                gcs_constants.get_trainer_asset_dir_gcs_path(
                    applied_task_identifier=applied_task_identifier
                )
            ]
            + gbml_config_pb_wrapper.trained_model_metadata_pb_wrapper.get_output_paths()
        )
        file_loader = FileLoader()
        logger.info(f"Will delete files @ the following paths: {paths_to_delete}")
        file_loader.delete_files(uris=paths_to_delete)

    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        cpu_docker_uri: Optional[str] = None,
        cuda_docker_uri: Optional[str] = None,
    ) -> None:
        """
        Run the trainer for the given task.

        Returns immediately when the config asks to skip both training and model
        evaluation. Existing trainer outputs are deleted only when training is
        going to run. The work is then handed to ``GLTTrainer`` when the config
        selects the GLT backend, and to ``TrainerV1`` otherwise.

        :param applied_task_identifier: Identifier of the job being run.
        :param task_config_uri: Uri of the GbmlConfig for this task.
        :param resource_config_uri: Uri of the resource config; forwarded unchanged
            to the underlying trainer.
        :param cpu_docker_uri: Optional CPU docker image; forwarded unchanged.
        :param cuda_docker_uri: Optional CUDA docker image; forwarded unchanged.
        :return: None
        """
        gbml_config_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
            gbml_config_uri=task_config_uri
        )
        shared_config = gbml_config_wrapper.shared_config

        if (
            shared_config.should_skip_training
            and shared_config.should_skip_model_evaluation
        ):
            logger.info("Skipping both training and evaluation. Exiting.")
            return

        # Only wipe previous outputs when training will actually regenerate them;
        # an evaluation-only run leaves the existing paths untouched.
        if not shared_config.should_skip_training:
            self.__remove_existing_trainer_paths(
                gbml_config_pb_wrapper=gbml_config_wrapper,
                applied_task_identifier=applied_task_identifier,
            )

        if gbml_config_wrapper.should_use_glt_backend:
            trainer_v2 = GLTTrainer()
            trainer_v2.run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                resource_config_uri=resource_config_uri,
                cpu_docker_uri=cpu_docker_uri,
                cuda_docker_uri=cuda_docker_uri,
            )
        else:
            trainer_v1 = TrainerV1()
            trainer_v1.run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                resource_config_uri=resource_config_uri,
                cpu_docker_uri=cpu_docker_uri,
                cuda_docker_uri=cuda_docker_uri,
            )
 
if __name__ == "__main__":
    # Command-line entry point: parse arguments, build the typed identifiers
    # and Uris, initialize metrics, then run the Trainer.
    parser = argparse.ArgumentParser(description="Program to train a GBML model")
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="Gbml config uri",
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="Runtime argument for resource and env specifications of each component",
    )
    parser.add_argument(
        "--cpu_docker_uri",
        type=str,
        help="User Specified or KFP compiled Docker Image for CPU training",
        required=False,
    )
    parser.add_argument(
        "--cuda_docker_uri",
        type=str,
        help="User Specified or KFP compiled Docker Image for GPU training",
        required=False,
    )
    args = parser.parse_args()

    # Validated by hand (rather than argparse `required=True`) so that empty
    # strings are rejected too and the failure stays a RuntimeError.
    mandatory_arg_names = ("job_name", "task_config_uri", "resource_config_uri")
    missing_arg_names = [
        arg_name for arg_name in mandatory_arg_names if not getattr(args, arg_name)
    ]
    if missing_arg_names:
        missing_flags = ", ".join(f"--{arg_name}" for arg_name in missing_arg_names)
        raise RuntimeError(f"Missing command-line arguments: {missing_flags}")

    applied_task_identifier = AppliedTaskIdentifier(args.job_name)
    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
    cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri

    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)

    trainer = Trainer()
    trainer.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        resource_config_uri=resource_config_uri,
        cpu_docker_uri=cpu_docker_uri,
        cuda_docker_uri=cuda_docker_uri,
    )