"""
This script runs a Kubeflow pipeline on Vertex AI (VAI).
You have options to RUN a pipeline, COMPILE a pipeline, or RUN a pipeline without compiling it,
i.e. when you already have a precompiled pipeline somewhere.
RUNNING A PIPELINE:
    python -m gigl.orchestration.kubeflow.runner --action=run ...args
    The following arguments are required:
        --task_config_uri: GCS URI to the template or frozen task config (template_or_frozen_config_uri).
        --resource_config_uri: GCS URI to the resource config (resource_config_uri).
        --container_image_cuda: GiGL source code image compiled for use with CUDA. See containers/Dockerfile.src
        --container_image_cpu: GiGL source code image compiled for use with CPU. See containers/Dockerfile.src
        --container_image_dataflow: GiGL source code image compiled for use with Dataflow. See containers/Dockerfile.dataflow.src
    The following arguments are optional:
        --job_name: The name to give to the KFP job. Default is "gigl_run_at_<current_time>"
        --start_at: The component to start the pipeline at. Default is config_populator. See gigl.src.common.constants.components.GiGLComponents
        --stop_after: The component to stop the pipeline at. Default is None.
        --pipeline_tag: Optional tag; if provided, it will be used to tag the pipeline description.
        --compiled_pipeline_path: The path where the compiled pipeline will be stored.
        --wait: Wait for the pipeline run to finish.
        --additional_job_args: Additional job arguments for the pipeline components, by component.
            The value has to be of the form "<gigl_component>.<arg_name>=<value>", where <gigl_component> is one of the
            string representations of the components specified in gigl.src.common.constants.components.GiGLComponents.
            This argument can be repeated.
            Example:
            --additional_job_args=subgraph_sampler.additional_spark35_jar_file_uris='gs://path/to/jar'
            --additional_job_args=split_generator.some_other_arg='value'
            This passes additional_spark35_jar_file_uris="gs://path/to/jar" to subgraph_sampler at compile time and
            some_other_arg="value" to split_generator at compile time.
        --run_labels: Labels to associate with the pipeline run.
            The value has to be of form: "<label_name>=<label_value>".
            NOTE: unlike SharedResourceConfig.resource_labels, these are *only* applied to the Vertex AI pipeline run.
            Example: --run_labels=gigl-integration-test=true --run_labels=user=me
        --notification_emails: Emails to send notification to.
            See https://cloud.google.com/vertex-ai/docs/pipelines/email-notifications for more details.
            Example: --notification_emails=user@example.com --notification_emails=user2@example.com
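    Example invocation (illustrative; the URIs and image names below are placeholders, not real resources):
        python -m gigl.orchestration.kubeflow.runner --action=run
            --task_config_uri=gs://<bucket>/path/to/task_config.yaml
            --resource_config_uri=gs://<bucket>/path/to/resource_config.yaml
            --container_image_cuda=<registry>/gigl-src-cuda:<tag>
            --container_image_cpu=<registry>/gigl-src-cpu:<tag>
            --container_image_dataflow=<registry>/gigl-src-dataflow:<tag>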
    Alternatively, you can use run_no_compile if you already have a precompiled pipeline:
    python -m gigl.orchestration.kubeflow.runner --action=run_no_compile ...args
    The following arguments are required:
        --task_config_uri
        --resource_config_uri
        --compiled_pipeline_path: The path to a pre-compiled pipeline; can be a GCS URI (gs://...) or a local path.
    The following arguments are optional:
        --job_name
        --start_at
        --stop_after
        --pipeline_tag
        --notification_emails
        --wait
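    Example invocation (illustrative; the paths below are placeholders, not real resources):
        python -m gigl.orchestration.kubeflow.runner --action=run_no_compile
            --task_config_uri=gs://<bucket>/path/to/task_config.yaml
            --resource_config_uri=gs://<bucket>/path/to/resource_config.yaml
            --compiled_pipeline_path=gs://<bucket>/path/to/compiled_pipeline.yaml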
COMPILING A PIPELINE:
    Compiling is a strict subset of running a pipeline:
    python -m gigl.orchestration.kubeflow.runner --action=compile ...args
    The following arguments are required:
        --container_image_cuda
        --container_image_cpu
        --container_image_dataflow
    The following arguments are optional:
        --compiled_pipeline_path: The path where the compiled pipeline will be stored.
        --pipeline_tag: Optional tag; if provided, it will be used to tag the pipeline description.
        --additional_job_args: Additional job arguments for the pipeline components, by component.
            The value has to be of the form "<gigl_component>.<arg_name>=<value>", where <gigl_component> is one of the
            string representations of the components specified in gigl.src.common.constants.components.GiGLComponents.
            This argument can be repeated.
            Example:
            --additional_job_args=subgraph_sampler.additional_spark35_jar_file_uris='gs://path/to/jar'
            --additional_job_args=split_generator.some_other_arg='value'
            This passes additional_spark35_jar_file_uris="gs://path/to/jar" to subgraph_sampler at compile time and
            some_other_arg="value" to split_generator at compile time.
"""
from __future__ import annotations
import argparse
from collections import defaultdict
from enum import Enum
from pathlib import Path
from gigl.common import UriFactory
from gigl.common.constants import (
    DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
    DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA,
    DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU,
)
from gigl.common.logger import Logger
from gigl.orchestration.img_builder import build_and_push_customer_src_images
from gigl.orchestration.kubeflow.kfp_orchestrator import (
    DEFAULT_KFP_COMPILED_PIPELINE_DEST_PATH,
    KfpOrchestrator,
)
from gigl.orchestration.kubeflow.kfp_pipeline import SPECED_COMPONENTS
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.utils.time import current_formatted_datetime
logger = Logger()

DEFAULT_JOB_NAME = f"gigl_run_at_{current_formatted_datetime()}"
DEFAULT_START_AT = GiGLComponents.ConfigPopulator.value


class Action(Enum):
    RUN = "run"
    COMPILE = "compile"
    RUN_NO_COMPILE = "run_no_compile"

    @staticmethod
    def from_string(s: str) -> Action:
        try:
            return Action(s)
        except ValueError:
            raise ValueError(
                f"Unknown action: {s}. Must be one of {[a.value for a in Action]}."
            )
 
_REQUIRED_RUN_FLAGS = frozenset(
    [
        "task_config_uri",
        "resource_config_uri",
    ]
)
_REQUIRED_RUN_NO_COMPILE_FLAGS = frozenset(
    [
        "task_config_uri",
        "resource_config_uri",
        "compiled_pipeline_path",
    ]
)
_REQUIRED_COMPILE_FLAGS = frozenset(
    [
        "compiled_pipeline_path",
    ]
)
def _assert_required_flags(args: argparse.Namespace) -> None:
    required_flags: frozenset[str]
    if args.action == Action.RUN:
        required_flags = _REQUIRED_RUN_FLAGS
    elif args.action == Action.RUN_NO_COMPILE:
        required_flags = _REQUIRED_RUN_NO_COMPILE_FLAGS
    elif args.action == Action.COMPILE:
        required_flags = _REQUIRED_COMPILE_FLAGS
    else:
        raise ValueError(f"Unknown action: {args.action}")
    missing_flags: list[str] = []
    missing_values: list[str] = []
    for flag in required_flags:
        if not hasattr(args, flag):
            missing_flags.append(flag)
        elif not getattr(args, flag):
            missing_values.append(flag)
    if missing_flags:
        raise ValueError(
            f"Missing the following flags for a {args.action} command: {missing_flags}. "
            + f"All required flags are: {list(required_flags)}."
        )
    if missing_values:
        raise ValueError(
            f"Missing values for the following flags for a {args.action} command: {missing_values}. "
            + f"All required flags are: {list(required_flags)}."
        )
    if args.action == Action.COMPILE and args.run_labels:
        raise ValueError(
            "Labels are not supported for the compile action. "
            "Please use the run action to run a pipeline with labels."
            f"Labels provided: {args.run_labels}"
        )
def _parse_additional_job_args(
    additional_job_args: list[str],
) -> dict[GiGLComponents, dict[str, str]]:
    """
    Parse the additional job arguments for the pipeline components, by component.
    Args:
        additional_job_args list[str]: Each element is of form: "<gigl_component>.<arg_name>=<value>"
            Where <gigl_component> is one of the string representations of the components specified in
            gigl.src.common.constants.components.GiGLComponents
            Example:
            ["subgraph_sampler.additional_spark35_jar_file_uris=gs://path/to/jar", "split_generator.some_other_arg=value"].
    Returns dict[GiGLComponents, dict[str, str]]: The parsed additional job arguments.
            Example for the example above: {
                GiGLComponents.SubgraphSampler: {
                    "additional_spark35_jar_file_uris"="gs://path/to/jar",
                },
                GiGLComponents.SplitGenerator: {
                    "some_other_arg": "value",
                },
            }
    """
    result: dict[GiGLComponents, dict[str, str]] = defaultdict(dict)
    for job_arg in additional_job_args:
        component_dot_arg, value = job_arg.split("=", 1)
        component_str, arg = component_dot_arg.split(".", 1)  # Handle nested keys
        component = GiGLComponents(component_str)
        # Build the nested dictionary dynamically
        result[component][arg] = value
    logger.info(f"Parsed additional job args: {result}")
    return dict(result)  # Ensure the default dict is converted to a regular dict
def _parse_labels(labels: list[str]) -> dict[str, str]:
    """
    Parse the labels for the pipeline run.
    Args:
        labels list[str]: Each element is of form: "<label_name>=<label_value>"
            Example: ["gigl-integration-test=true", "user=me"].
    Returns dict[str, str]: The parsed labels.
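            Example for the labels above: {"gigl-integration-test": "true", "user": "me"}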
    """
    result: dict[str, str] = {}
    for label in labels:
        label_name, label_value = label.split("=", 1)
        result[label_name] = label_value
    logger.info(f"Parsed labels: {result}")
    return result
def _get_parser() -> argparse.ArgumentParser:
    """
    Get the parser for the runner.py script.
    """
    parser = argparse.ArgumentParser(
        description="Create the KF pipeline for GNN preprocessing/training/inference"
    )
    parser.add_argument(
        "--action",
        type=Action.from_string,
        choices=list(Action),
        required=True,
    )
    parser.add_argument(
        "--job_name",
        help="Runtime argument for running the pipeline. The name to give to the KFP job.",
        default=DEFAULT_JOB_NAME,
    )
    parser.add_argument(
        "--container_image_cuda",
        default=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA,
        help="The docker image name and tag to use for cuda pipeline components ",
    )
    parser.add_argument(
        "--container_image_cpu",
        default=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
        help="The docker image name and tag to use for cpu pipeline components ",
    )
    parser.add_argument(
        "--container_image_dataflow",
        default=DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU,
        help="The docker image name and tag to use for the worker harness in dataflow ",
    )
    parser.add_argument(
        "--start_at",
        help="Runtime argument for running the pipeline. Specify the component where to start the pipeline.",
        choices=SPECED_COMPONENTS,
        default=DEFAULT_START_AT,
    )
    parser.add_argument(
        "--stop_after",
        help="Runtime argument for running the pipeline. Specify the component where to stop the pipeline.",
        choices=SPECED_COMPONENTS,
        default=None,
    )
    parser.add_argument(
        "--task_config_uri",
        help="Runtime argument for running the pipeline. GCS URI to template_or_frozen_config_uri.",
    )
    parser.add_argument(
        "--resource_config_uri",
        help="Runtine argument for resource and env specifications of each component",
    )
    parser.add_argument(
        "--wait",
        help="Wait for the pipeline run to finish",
        action="store_true",
    )
    parser.add_argument(
        "--pipeline_tag", "-t", help="Tag for the pipeline definition", default=None
    )
    parser.add_argument(
        "--extra_source_dir",
        help="""The path to a local dir that will be built into new src docker images based off of
        `container_image_cuda`, `container_image_cpu`, and `container_image_dataflow`""",
        default=None,
    )
    parser.add_argument(
        "--compiled_pipeline_path",
        help="A custom URI that points to where you want the compiled pipeline is to be saved to."
        + "In the case you want to run an existing pipeline that you are not compiling, this is the path to the compiled pipeline.",
    )
    parser.add_argument(
        "--export_docker_artifact_registry",
        help="The docker artifact registry to push the customer src images to. For example: us-central1-docker.pkg.dev/some_project_name/gigl-base-images",
        default=None,
    )
    parser.add_argument(
        "--additional_job_args",
        action="append",  # Allow multiple occurrences of this argument
        default=[],
        help="""Additional pipeline job arguments by component of form:
        "gigl_component.key=value,gigl_component.key_2=value_2"
        Example: --additional_job_args=subgraph_sampler.additional_spark35_jar_file_uris='gs://path/to/jar'
            --additional_job_args=split_generator.some_other_arg='value'
        This passes additional_spark35_jar_file_uris="gs://path/to/jar" to subgraph_sampler at compile time and
        some_other_arg="value" to split_generator at compile time.
        """,
    )
    parser.add_argument(
        "--run_labels",
        action="append",
        default=[],
        help="""Labels to associate with the pipeline run, of the form: --run_labels=label_name=label_value.
        Only applicable for run and run_no_compile actions.
        NOTE: unlike SharedResourceConfig.resource_labels, these are *only* applied to the Vertex AI pipeline run.
        Example: --run_labels=gigl-integration-test=true --run_labels=user=me
        which will tag the pipeline run with gigl-integration-test=true and user=me.
        """,
    )
    parser.add_argument(
        "--notification_emails",
        action="append",
        default=[],
        help="Email to send notification to. See https://cloud.google.com/vertex-ai/docs/pipelines/email-notifications for more details.",
    )
    return parser
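# Illustrative sketch (commented out, not executed): how _get_parser() and the helper
# functions above could be driven programmatically. The argument values below are
# placeholders, not real resources.
#
#     parser = _get_parser()
#     example_args = parser.parse_args(
#         [
#             "--action=compile",
#             "--compiled_pipeline_path=/tmp/compiled_gigl_pipeline.yaml",
#             "--additional_job_args=subgraph_sampler.some_arg=some_value",
#         ]
#     )
#     _assert_required_flags(example_args)
#     parsed_args = _parse_additional_job_args(example_args.additional_job_args)
#     # parsed_args == {GiGLComponents.SubgraphSampler: {"some_arg": "some_value"}}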
if __name__ == "__main__":
    parser = _get_parser()
    args = parser.parse_args()
    logger.info(f"Beginning runner.py with args: {args}")
    # Assert correctness of args
    _assert_required_flags(args)
    parsed_additional_job_args = _parse_additional_job_args(args.additional_job_args)
    parsed_labels = _parse_labels(args.run_labels)
    # Set the default value for compiled_pipeline_path here rather than in argparse:
    # for the compile action it is a required flag, so it cannot be given a default
    # value in argparse. See _assert_required_flags for more details.
    if args.compiled_pipeline_path:
        compiled_pipeline_path = UriFactory.create_uri(args.compiled_pipeline_path)
    else:
        compiled_pipeline_path = DEFAULT_KFP_COMPILED_PIPELINE_DEST_PATH
    cuda_container_image = args.container_image_cuda
    cpu_container_image = args.container_image_cpu
    dataflow_container_image = args.container_image_dataflow
    if args.extra_source_dir:
        # We need to rebuild the src docker images with the extra source dir
        export_docker_artifact_registry = args.export_docker_artifact_registry
        extra_source_dir_posix_path = Path(args.extra_source_dir).absolute().as_posix()
        (
            cuda_container_image,
            cpu_container_image,
            dataflow_container_image,
        ) = build_and_push_customer_src_images(
            base_image_cuda=cuda_container_image,
            base_image_cpu=cpu_container_image,
            base_image_dataflow=dataflow_container_image,
            export_docker_artifact_registry=export_docker_artifact_registry,
            context_path=extra_source_dir_posix_path,
        )
    if args.action in (Action.RUN, Action.RUN_NO_COMPILE):
        orchestrator = KfpOrchestrator()
        task_config_uri = UriFactory.create_uri(args.task_config_uri)
        resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
        applied_task_identifier = AppliedTaskIdentifier(args.job_name)
        if args.action == Action.RUN:
            path = orchestrator.compile(
                cuda_container_image=cuda_container_image,
                cpu_container_image=cpu_container_image,
                dataflow_container_image=dataflow_container_image,
                dst_compiled_pipeline_path=compiled_pipeline_path,
                additional_job_args=parsed_additional_job_args,
                tag=args.pipeline_tag,
            )
            assert (
                path == compiled_pipeline_path
            ), f"Compiled pipeline path {path} does not match provided path {compiled_pipeline_path}"
        run = orchestrator.run(
            applied_task_identifier=applied_task_identifier,
            task_config_uri=task_config_uri,
            resource_config_uri=resource_config_uri,
            start_at=args.start_at,
            stop_after=args.stop_after,
            compiled_pipeline_path=compiled_pipeline_path,
            labels=parsed_labels if parsed_labels else None,
            notification_emails=args.notification_emails,
        )
        if args.wait:
            orchestrator.wait_for_completion(run=run)
    elif args.action == Action.COMPILE:
        pipeline_bundle_path = KfpOrchestrator.compile(
            cuda_container_image=cuda_container_image,
            cpu_container_image=cpu_container_image,
            dataflow_container_image=dataflow_container_image,
            dst_compiled_pipeline_path=compiled_pipeline_path,
            additional_job_args=parsed_additional_job_args,
            tag=args.pipeline_tag,
        )
        logger.info(
            f"Pipeline finished compiling, exported to: {pipeline_bundle_path.uri}"
        )
    else:
        raise ValueError(f"Unknown action: {args.action}")