import argparse
from typing import Optional, Union
import torch
from google.cloud.aiplatform_v1.types import accelerator_type
from gigl.common import Uri, UriFactory
from gigl.common.constants import (
    DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
    DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA,
)
from gigl.common.logger import Logger
from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from gigl.src.training.v1.lib.training_process import GnnTrainingProcess
from snapchat.research.gbml.gigl_resource_config_pb2 import (
    LocalResourceConfig,
    VertexAiResourceConfig,
)
class Trainer:
    """
    GiGL Component that trains a GNN model using the specified task and resource configurations.
    """
    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        cpu_docker_uri: Optional[str] = None,
        cuda_docker_uri: Optional[str] = None,
    ) -> None:
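        """Runs training either on Vertex AI or locally, depending on the resource config.

        Args:
            applied_task_identifier: Unique identifier for the job.
            task_config_uri: URI to the GBML task config.
            resource_config_uri: URI to the resource config.
            cpu_docker_uri: Docker image to use for CPU training; defaults to
                the GiGL release CPU image when unset.
            cuda_docker_uri: Docker image to use for GPU training; defaults to
                the GiGL release CUDA image when unset.
        """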
        resource_config = get_resource_config(resource_config_uri=resource_config_uri)
        trainer_config = resource_config.trainer_config
        is_cpu_training = self._determine_if_cpu_training(trainer_config)
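        # Remote path: launch the training process as a Vertex AI custom job.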
        if isinstance(trainer_config, VertexAiResourceConfig):
            cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
            cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
            container_uri = cpu_docker_uri if is_cpu_training else cuda_docker_uri
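            # Forward the job identifiers to the containerized training
            # process; GPU jobs additionally opt in to CUDA via --use_cuda.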
            job_args = [
                f"--job_name={applied_task_identifier}",
                f"--task_config_uri={task_config_uri}",
                f"--resource_config_uri={resource_config_uri}",
            ] + ([] if is_cpu_training else ["--use_cuda"])
            job_config = VertexAiJobConfig(
                job_name=applied_task_identifier,
                container_uri=container_uri,
                command=["python", "-m", "gigl.src.training.v1.lib.training_process"],
                args=job_args,
                environment_variables=[
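                    # TF_CPP_MIN_LOG_LEVEL=3 silences TensorFlow's C++ INFO,
                    # WARNING, and ERROR logs inside the container.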
                    {"name": "TF_CPP_MIN_LOG_LEVEL", "value": "3"},
                ],
                machine_type=trainer_config.machine_type,
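                # Vertex AI expects the accelerator enum name, e.g.
                # "nvidia-tesla-t4" becomes "NVIDIA_TESLA_T4".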
                accelerator_type=trainer_config.gpu_type.upper().replace("-", "_"),
                accelerator_count=trainer_config.gpu_limit,
                replica_count=trainer_config.num_replicas,
                labels=resource_config.get_resource_labels(
                    component=GiGLComponents.Trainer
                ),
                timeout_s=trainer_config.timeout if trainer_config.timeout else None,
            )
            vertex_ai_service = VertexAIService(
                project=resource_config.project,
                location=resource_config.region,
                service_account=resource_config.service_account_email,
                staging_bucket=resource_config.temp_assets_regional_bucket_path.uri,
            )
            vertex_ai_service.launch_job(job_config=job_config)
        elif isinstance(trainer_config, LocalResourceConfig):
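            # Local path: run training in-process, using CUDA only when GPU
            # training was requested and a CUDA device is available.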
            training_process = GnnTrainingProcess()
            training_process.run(
                task_config_uri=task_config_uri,
                device=torch.device(
                    "cuda"
                    if not is_cpu_training and torch.cuda.is_available()
                    else "cpu"
                ),
            )
        else:
            raise ValueError(
                f"Unsupported trainer_config in resource_config: {type(trainer_config).__name__}"
            ) 
    def _determine_if_cpu_training(
        self, trainer_config: Union[LocalResourceConfig, VertexAiResourceConfig]
    ) -> bool:
        """Determine whether CPU training is required based on the trainer configuration.

        Local training always runs on CPU; Vertex AI training falls back to CPU
        when no GPU accelerator type is specified.
        """
        if isinstance(trainer_config, LocalResourceConfig):
            return True
        elif hasattr(trainer_config, "gpu_type") and (
            trainer_config.gpu_type is None
            or trainer_config.gpu_type
            == accelerator_type.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED
        ):
            return True
        else:
            return False
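

# Example programmatic usage (a minimal sketch; the job name and URIs below are
# hypothetical placeholders):
#
#     trainer = Trainer()
#     trainer.run(
#         applied_task_identifier=AppliedTaskIdentifier("my_gigl_job"),
#         task_config_uri=UriFactory.create_uri("gs://my-bucket/task_config.yaml"),
#         resource_config_uri=UriFactory.create_uri("gs://my-bucket/resource_config.yaml"),
#     )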
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Program to train a GNN model using GiGL"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="Gbml config uri",
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="Runtime argument for resource and env specifications of each component",
    )
    parser.add_argument(
        "--cpu_docker_uri",
        type=str,
        help="User Specified or KFP compiled Docker Image for CPU training",
        required=False,
    )
    parser.add_argument(
        "--cuda_docker_uri",
        type=str,
        help="User Specified or KFP compiled Docker Image for GPU training",
        required=False,
    )
    args = parser.parse_args()
    if not args.job_name or not args.task_config_uri or not args.resource_config_uri:
        raise RuntimeError(
            "Missing one or more required command-line arguments: "
            "--job_name, --task_config_uri, --resource_config_uri"
        )
    applied_task_identifier = AppliedTaskIdentifier(args.job_name)
    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
    cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri
    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)
    trainer = Trainer()
    trainer.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        resource_config_uri=resource_config_uri,
        cpu_docker_uri=cpu_docker_uri,
        cuda_docker_uri=cuda_docker_uri,
    )
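
# Example command-line invocation (a minimal sketch; the module path, job name,
# and URIs are hypothetical placeholders):
#
#   python -m gigl.src.training.v1.trainer \
#       --job_name=my_gigl_job \
#       --task_config_uri=gs://my-bucket/task_config.yaml \
#       --resource_config_uri=gs://my-bucket/resource_config.yaml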