"""Class for interacting with Vertex AI.
Below are some brief definitions of the terminology used by Vertex AI Pipelines:
Resource name: A globally unique identifier for the pipeline, follows https://google.aip.dev/122 and is of the form projects/<project-id>/locations/<location>/pipelineJobs/<job-name>
Job name: aka job_id aka PipelineJob.name the name of a pipeline run, must be unique for a given project and location
Display name: AFAICT purely cosmetic name for a pipeline, can be filtered on but does not show up in the UI
Pipeline name: The name for the pipeline supplied by the pipeline definition (pipeline.yaml).
And a walkthrough to explain how the terminology is used:
```py
import os
import tempfile

import kfp
from google.cloud import aiplatform as aip


@kfp.dsl.component
def source() -> int:
    return 42


@kfp.dsl.component
def doubler(a: int) -> int:
    return a * 2


@kfp.dsl.component
def adder(a: int, b: int) -> int:
    return a + b


@kfp.dsl.pipeline
def get_pipeline() -> int:  # NOTE: `get_pipeline` here is the Pipeline name
    source_task = source()
    double_task = doubler(a=source_task.output)
    adder_task = adder(a=source_task.output, b=double_task.output)
    return adder_task.output


tempdir = tempfile.TemporaryDirectory()
tf = os.path.join(tempdir.name, "pipeline.yaml")
print(f"Writing pipeline definition to {tf}")
kfp.compiler.Compiler().compile(get_pipeline, tf)

job = aip.PipelineJob(
    display_name="this_is_our_pipeline_display_name",
    template_path=tf,
    pipeline_root="gs://my-bucket/pipeline-root",
)
job.submit(service_account="my-sa@my-project.gserviceaccount.com")
```
Which outputs the following:

    Creating PipelineJob
    PipelineJob created. Resource name: projects/my-project-id/locations/us-central1/pipelineJobs/get-pipeline-20250226170755
    To use this PipelineJob in another session:
    pipeline_job = aiplatform.PipelineJob.get('projects/my-project-id/locations/us-central1/pipelineJobs/get-pipeline-20250226170755')
    View Pipeline Job:
    https://console.cloud.google.com/vertex-ai/locations/us-central1/pipelines/runs/get-pipeline-20250226170755?project=my-project-id
    Associating projects/my-project-id/locations/us-central1/pipelineJobs/get-pipeline-20250226170755 to Experiment: example-experiment
And `job` has some properties set as well:
```py
print(f"{job.display_name=}") # job.display_name='this_is_our_pipeline_display_name'
print(f"{job.resource_name=}") # job.resource_name='projects/my-project-id/locations/us-central1/pipelineJobs/get-pipeline-20250226170755'
print(f"{job.name=}") # job.name='get-pipeline-20250226170755' # NOTE: by default, the "job name" is the pipeline name + datetime
```
"""
import datetime
import time
from dataclasses import dataclass
from typing import Final, Optional, Union
from google.cloud import aiplatform
from google.cloud.aiplatform_v1.types import (
    ContainerSpec,
    DiskSpec,
    MachineSpec,
    WorkerPoolSpec,
    env_var,
)
from gigl.common import GcsUri, Uri
from gigl.common.logger import Logger
logger = Logger()

LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY: Final[
    str
] = "LEADER_WORKER_INTERNAL_IP_FILE_PATH"

DEFAULT_PIPELINE_TIMEOUT_S: Final[int] = 60 * 60 * 36  # 36 hours

DEFAULT_CUSTOM_JOB_TIMEOUT_S: Final[int] = 60 * 60 * 24  # 24 hours

@dataclass
class VertexAiJobConfig:
    # Required fields; each is referenced elsewhere in this module.
    job_name: str
    container_uri: str
    command: list[str]

    args: Optional[list[str]] = None
    environment_variables: Optional[list[dict[str, str]]] = None
    machine_type: str = "n1-standard-4"
    accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED"
    accelerator_count: int = 0
    replica_count: int = 1  # Total replicas, including the leader worker
    boot_disk_type: str = "pd-ssd"  # Persistent Disk SSD
    boot_disk_size_gb: int = 100  # Default disk size in GB
    labels: Optional[dict[str, str]] = None
    # Will default to DEFAULT_CUSTOM_JOB_TIMEOUT_S if not provided
    timeout_s: Optional[int] = None
    enable_web_access: bool = True


class VertexAIService:
    """
    A class representing a Vertex AI service.
    Args:
        project (str): The project ID.
        location (str): The location of the service.
        service_account (str): The service account to use for authentication.
        staging_bucket (str): The staging bucket for the service.
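
    Example (a minimal usage sketch; the project, location, service account,
    and bucket values below are hypothetical):

    ```py
    service = VertexAIService(
        project="my-project",
        location="us-central1",
        service_account="my-sa@my-project.iam.gserviceaccount.com",
        staging_bucket="gs://my-bucket/staging",
    )
    ```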
    """
    def __init__(
        self,
        project: str,
        location: str,
        service_account: str,
        staging_bucket: str,
    ):
        self._project = project
        self._location = location
        self._service_account = service_account
        self._staging_bucket = staging_bucket
        aiplatform.init(
            project=self._project,
            location=self._location,
            staging_bucket=self._staging_bucket,
        )
    @property
    def project(self) -> str:
        """The GCP project that is being used for this service."""
        return self._project

    def launch_job(self, job_config: VertexAiJobConfig) -> aiplatform.CustomJob:
        """
        Launch a Vertex AI CustomJob.
        See the docs for more info.
        https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.CustomJob
        Args:
            job_config (VertexAiJobConfig): The configuration for the job.
        Returns:
            The completed CustomJob.
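
        Example (a minimal sketch; the job name, image URI, and command are
        hypothetical, and `service` is a VertexAIService instance):

        ```py
        config = VertexAiJobConfig(
            job_name="my-custom-job",
            container_uri="gcr.io/my-project/my-image:latest",
            command=["python", "-m", "my_package.train"],
            machine_type="n1-standard-8",
            replica_count=2,  # one leader + one worker
        )
        job = service.launch_job(config)
        ```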
        """
        logger.info(f"Running Vertex AI job: {job_config.job_name}")
        machine_spec = _create_machine_spec(job_config)
        # This file is used to store the leader worker's internal IP address.
        # Whenever `connect_worker_pool()` is called, the leader worker will
        # write its internal IP address to this file. The other workers will
        # read this file to get the leader worker's internal IP address.
        # See connect_worker_pool() implementation for more details.
        leader_worker_internal_ip_file_path = GcsUri.join(
            self._staging_bucket,
            job_config.job_name,
            datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
            "leader_worker_internal_ip.txt",
        )
        env_vars = [
            env_var.EnvVar(
                name=LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY,
                value=leader_worker_internal_ip_file_path.uri,
            )
        ]
        container_spec = _create_container_spec(job_config, env_vars)
        disk_spec = _create_disk_spec(job_config)
        assert (
            job_config.replica_count >= 1
        ), "Replica count must be at minimum 1, i.e. the leader worker"
        leader_worker_spec = WorkerPoolSpec(
            machine_spec=machine_spec,
            container_spec=container_spec,
            disk_spec=disk_spec,
            replica_count=1,
        )
        worker_pool_specs: list[WorkerPoolSpec] = [leader_worker_spec]
        if job_config.replica_count > 1:
            worker_spec = WorkerPoolSpec(
                machine_spec=machine_spec,
                container_spec=container_spec,
                disk_spec=disk_spec,
                replica_count=job_config.replica_count - 1,
            )
            worker_pool_specs.append(worker_spec)
        logger.info(
            f"Running Custom job {job_config.job_name} with worker_pool_specs {worker_pool_specs}, in project: {self._project}/{self._location} using staging bucket: {self._staging_bucket}, and attached labels: {job_config.labels}"
        )
        if not job_config.timeout_s:
            logger.info(
                f"No timeout set for Vertex AI job, setting default timeout to {DEFAULT_CUSTOM_JOB_TIMEOUT_S/60/60} hours"
            )
            job_config.timeout_s = DEFAULT_CUSTOM_JOB_TIMEOUT_S
        else:
            logger.info(
                f"Running Vertex AI job with timeout {job_config.timeout_s} seconds"
            )
        job = aiplatform.CustomJob(
            display_name=job_config.job_name,
            worker_pool_specs=worker_pool_specs,
            project=self._project,
            location=self._location,
            labels=job_config.labels,
            staging_bucket=self._staging_bucket,
        )
        job.submit(
            service_account=self._service_account,
            timeout=job_config.timeout_s,
            enable_web_access=job_config.enable_web_access,
        )
        job.wait_for_resource_creation()
        logger.info(f"Created job: {job.resource_name}")
        # Copying https://github.com/googleapis/python-aiplatform/blob/v1.48.0/google/cloud/aiplatform/jobs.py#L207-L215
        # Since for some reason upgrading from VertexAI v1.27.1 to v1.48.0
        # caused the logs to occasionally not be printed.
        logger.info(
            f"See job logs at: https://console.cloud.google.com/ai/platform/locations/{self._location}/training/{job.name}?project={self._project}"
        )
        job.wait_for_completion()
        return job

    def launch_graph_store_job(
        self,
        compute_pool_job_config: VertexAiJobConfig,
        storage_pool_job_config: VertexAiJobConfig,
    ) -> aiplatform.CustomJob:
        """Launch a Vertex AI Graph Store job.
        This launches one Vertex AI CustomJob with two worker pools, see
        https://cloud.google.com/vertex-ai/docs/training/distributed-training
        for more details.
        NOTE:
            We use the job_name, timeout, and enable_web_access from the compute pool job config.
            These fields, if set on the storage pool job config, will be ignored.
        Args:
            compute_pool_job_config (VertexAiJobConfig): The configuration for the compute pool job.
            storage_pool_job_config (VertexAiJobConfig): The configuration for the storage pool job.
        Returns:
            The completed CustomJob.
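
        Example (a minimal sketch; the names and image URIs are hypothetical,
        and `service` is a VertexAIService instance):

        ```py
        compute_config = VertexAiJobConfig(
            job_name="graph-store-job",
            container_uri="gcr.io/my-project/compute-image:latest",
            command=["python", "-m", "my_package.compute"],
            replica_count=2,
        )
        storage_config = VertexAiJobConfig(
            job_name="ignored",  # job_name is taken from the compute config
            container_uri="gcr.io/my-project/storage-image:latest",
            command=["python", "-m", "my_package.storage"],
            replica_count=2,
        )
        job = service.launch_graph_store_job(compute_config, storage_config)
        ```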
        """
        storage_machine_spec = _create_machine_spec(storage_pool_job_config)
        compute_machine_spec = _create_machine_spec(compute_pool_job_config)
        storage_disk_spec = _create_disk_spec(storage_pool_job_config)
        compute_disk_spec = _create_disk_spec(compute_pool_job_config)
        storage_container_spec = _create_container_spec(storage_pool_job_config)
        compute_container_spec = _create_container_spec(compute_pool_job_config)
        worker_pool_specs: list[Union[WorkerPoolSpec, dict]] = []
        leader_worker_spec = WorkerPoolSpec(
            machine_spec=compute_machine_spec,
            container_spec=compute_container_spec,
            disk_spec=compute_disk_spec,
            replica_count=1,
        )
        worker_pool_specs.append(leader_worker_spec)
        if compute_pool_job_config.replica_count > 1:
            worker_spec = WorkerPoolSpec(
                machine_spec=compute_machine_spec,
                container_spec=compute_container_spec,
                disk_spec=compute_disk_spec,
                replica_count=compute_pool_job_config.replica_count - 1,
            )
            worker_pool_specs.append(worker_spec)
        else:
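            # Vertex AI worker pool specs are positional; append an empty spec
            # as a placeholder so the storage pool below lands in the next pool
            # index even when there are no extra compute workers.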
            worker_pool_specs.append({})
        worker_spec = WorkerPoolSpec(
            machine_spec=storage_machine_spec,
            container_spec=storage_container_spec,
            disk_spec=storage_disk_spec,
            replica_count=storage_pool_job_config.replica_count,
        )
        worker_pool_specs.append(worker_spec)
        job = aiplatform.CustomJob(
            display_name=compute_pool_job_config.job_name,
            worker_pool_specs=worker_pool_specs,
            project=self._project,
            location=self._location,
            labels=compute_pool_job_config.labels,
            staging_bucket=self._staging_bucket,
        )
        job.submit(
            service_account=self._service_account,
            timeout=compute_pool_job_config.timeout_s,
            enable_web_access=compute_pool_job_config.enable_web_access,
        )
        job.wait_for_resource_creation()
        logger.info(f"Created job: {job.resource_name}")
        # Copying https://github.com/googleapis/python-aiplatform/blob/v1.48.0/google/cloud/aiplatform/jobs.py#L207-L215
        # Since for some reason upgrading from VertexAI v1.27.1 to v1.48.0
        # caused the logs to occasionally not be printed.
        logger.info(
            f"See job logs at: https://console.cloud.google.com/ai/platform/locations/{self._location}/training/{job.name}?project={self._project}"
        )
        job.wait_for_completion()
        return job

    def run_pipeline(
        self,
        display_name: str,
        template_path: Uri,
        run_keyword_args: dict[str, str],
        job_id: Optional[str] = None,
        labels: Optional[dict[str, str]] = None,
        experiment: Optional[str] = None,
    ) -> aiplatform.PipelineJob:
        """
        Runs a pipeline using the Vertex AI Pipelines service.
        For more info, see the Vertex AI docs
        https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.PipelineJob#google_cloud_aiplatform_PipelineJob_submit
        Args:
            display_name (str): The display name of the pipeline.
            template_path (Uri): The path to the compiled pipeline YAML.
            run_keyword_args (dict[str, str]): Runtime arguments passed to your pipeline.
            job_id (Optional[str]): The ID of the job. If not provided, will be the *pipeline_name* + datetime.
                                    Note: The pipeline_name and display_name are *not* the same.
                                    Note: pipeline_name is defined in the `template_path` and ultimately comes from the Python pipeline definition.
                                    If provided, must be unique.
            labels (Optional[dict[str, str]]): Labels to associate with the run.
            experiment (Optional[str]): The name of the experiment to associate the run with.
        Returns:
            The PipelineJob created.
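
        Example (a minimal sketch; the paths, names, and parameters are
        hypothetical, and `service` is a VertexAIService instance):

        ```py
        job = service.run_pipeline(
            display_name="my-pipeline-run",
            template_path=GcsUri("gs://my-bucket/pipeline.yaml"),
            run_keyword_args={"learning_rate": "0.01"},
            experiment="my-experiment",
        )
        ```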
        """
        job = aiplatform.PipelineJob(
            display_name=display_name,
            template_path=template_path.uri,
            parameter_values=run_keyword_args,
            job_id=job_id,
            project=self._project,
            location=self._location,
            labels=labels,
        )
        job.submit(service_account=self._service_account, experiment=experiment)
        logger.info(f"Created run: {job.resource_name}")
        if experiment:
            logger.info(
                f"Associated run {job.resource_name} with experiment: {experiment}"
            )
        if labels:
            logger.info(f"Associated run {job.resource_name} with labels: {labels}")
        return job

    def get_pipeline_job_from_job_name(self, job_name: str) -> aiplatform.PipelineJob:
        """Fetches the pipeline job with the given job name."""
        return aiplatform.PipelineJob.get(
            f"projects/{self._project}/locations/{self._location}/pipelineJobs/{job_name}"
        )

    @staticmethod
    def get_pipeline_run_url(project: str, location: str, job_name: str) -> str:
        """Returns the URL for the pipeline run."""
        return f"https://console.cloud.google.com/vertex-ai/locations/{location}/pipelines/runs/{job_name}?project={project}" 

    @staticmethod
    def wait_for_run_completion(
        resource_name: str,
        timeout: float = DEFAULT_PIPELINE_TIMEOUT_S,
        polling_period_s: int = 60,
    ) -> None:
        """
        Waits for a run to complete.
        Args:
            resource_name (str): The resource name of the run.
            timeout (float): The maximum time to wait for the run to complete, in seconds. Defaults to DEFAULT_PIPELINE_TIMEOUT_S (36 hours).
            polling_period_s (int): The time to wait between polling the run status, in seconds. Defaults to 60.
        Returns:
            None
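
        Example (a minimal sketch; the resource name is hypothetical):

        ```py
        VertexAIService.wait_for_run_completion(
            resource_name="projects/my-project/locations/us-central1/pipelineJobs/my-run",
            timeout=60 * 60,  # give up after one hour
            polling_period_s=30,
        )
        ```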
        """
        start_time = time.time()
        logger.info(f"Waiting for run to complete: {resource_name}")
        run = aiplatform.PipelineJob.get(resource_name=resource_name)
        while start_time + timeout > time.time():
            # Note that accesses to `run.state` cause a network call under the hood.
            # We should be careful with accessing this too frequently, and "cache"
            # the state if we need to access it multiple times in short succession.
            state = run.state
            logger.info(
                f"Run {resource_name} in state: {state.name if state else state}"
            )
            if state == aiplatform.gapic.PipelineState.PIPELINE_STATE_SUCCEEDED:
                logger.info("Vertex AI finished with status Succeeded!")
                return
            elif state in (
                aiplatform.gapic.PipelineState.PIPELINE_STATE_FAILED,
                aiplatform.gapic.PipelineState.PIPELINE_STATE_CANCELLED,
            ):
                logger.warning(f"Vertex AI run stopped with status: {state.name}.")
                logger.warning(
                    f"See run at: {VertexAIService.get_pipeline_run_url(run.project, run.location, run.name)}"
                )
                raise RuntimeError(f"Vertex AI run stopped with status: {state.name}.")
            time.sleep(polling_period_s)
        else:
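            # NOTE: This `else` belongs to the `while` loop above; it runs only
            # if the loop condition becomes false, i.e. on timeout.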
            logger.warning("Timeout reached. Stopping the run.")
            logger.warning(
                f"See run at: {VertexAIService.get_pipeline_run_url(run.project, run.location, run.name)}"
            )
            run.cancel()
            raise RuntimeError(
                f"Vertex AI run stopped with status: {run.state}. "
                f"Please check the Vertex AI page to trace down the error."
            ) 
 
def _create_machine_spec(job_config: VertexAiJobConfig) -> MachineSpec:
    """Get the machine spec for a job config."""
    machine_spec = MachineSpec(
        machine_type=job_config.machine_type,
        accelerator_type=job_config.accelerator_type,
        accelerator_count=job_config.accelerator_count,
    )
    return machine_spec
def _create_container_spec(
    job_config: VertexAiJobConfig, env_vars: Optional[list[env_var.EnvVar]] = None
) -> ContainerSpec:
    """Get the container spec for a job config."""
    container_spec = ContainerSpec(
        image_uri=job_config.container_uri,
        command=job_config.command,
        args=job_config.args,
        env=env_vars,
    )
    return container_spec
def _create_disk_spec(job_config: VertexAiJobConfig) -> DiskSpec:
    """Get the disk spec for a job config."""
    disk_spec = DiskSpec(
        boot_disk_type=job_config.boot_disk_type,
        boot_disk_size_gb=job_config.boot_disk_size_gb,
    )
    return disk_spec