import argparse
from typing import Optional
from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
    GiglResourceConfigWrapper,
)
from gigl.src.validation_check.libs.frozen_config_path_checks import (
    assert_preprocessed_metadata_exists,
    assert_split_generator_output_exists,
    assert_subgraph_sampler_output_exists,
    assert_trained_model_exists,
)
from gigl.src.validation_check.libs.name_checks import (
    check_if_kfp_pipeline_job_name_valid,
)
from gigl.src.validation_check.libs.resource_config_checks import (
    check_if_inferencer_resource_config_valid,
    check_if_preprocessor_resource_config_valid,
    check_if_shared_resource_config_valid,
    check_if_split_generator_resource_config_valid,
    check_if_subgraph_sampler_resource_config_valid,
    check_if_trainer_resource_config_valid,
)
from gigl.src.validation_check.libs.template_config_checks import (
    check_if_data_preprocessor_config_cls_valid,
    check_if_graph_metadata_valid,
    check_if_inferencer_cls_valid,
    check_if_post_processor_cls_valid,
    check_if_preprocessed_metadata_valid,
    check_if_split_generator_config_valid,
    check_if_subgraph_sampler_config_valid,
    check_if_task_metadata_valid,
    check_if_trainer_cls_valid,
    check_pipeline_has_valid_start_and_stop_flags,
)
from snapchat.research.gbml import gbml_config_pb2
from snapchat.research.gbml.gigl_resource_config_pb2 import GiglResourceConfig
# Template-config checks for runs bounded by an exact (start_at, stop_after) pair.
# kfp_validation_checks consults this map first; when the pair is registered here
# these checks are run INSTEAD OF the START_COMPONENT_TO_CLS_CHECKS_MAP entry for
# start_at. Every check is invoked as check(gbml_config_pb=...).
START_STOP_COMPONENT_TO_CLS_CHECKS_MAP = {
    # TODO: (svij-sc) Add checks as needed, otherwise we default to below anyways
    (GiGLComponents.SubgraphSampler.value, GiGLComponents.SubgraphSampler.value): [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_subgraph_sampler_config_valid,
    ],
}
# Template-config checks keyed by the component the pipeline starts at.
# kfp_validation_checks falls back to this map when the (start_at, stop_after)
# pair is not registered in START_STOP_COMPONENT_TO_CLS_CHECKS_MAP; a start_at
# value with no entry here gets no checks. Every check is invoked as
# check(gbml_config_pb=...).
START_COMPONENT_TO_CLS_CHECKS_MAP = {
    GiGLComponents.ConfigPopulator.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_data_preprocessor_config_cls_valid,
        check_if_subgraph_sampler_config_valid,
        check_if_split_generator_config_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.DataPreprocessor.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_data_preprocessor_config_cls_valid,
        check_if_subgraph_sampler_config_valid,
        check_if_split_generator_config_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    # From here on the data preprocessor is not run, so its config class check is
    # dropped and the already-produced preprocessed metadata is validated instead.
    GiGLComponents.SubgraphSampler.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_preprocessed_metadata_valid,
        check_if_subgraph_sampler_config_valid,
        check_if_split_generator_config_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.SplitGenerator.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_preprocessed_metadata_valid,
        check_if_split_generator_config_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.Trainer.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_preprocessed_metadata_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.Inferencer.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_preprocessed_metadata_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.PostProcessor.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_post_processor_cls_valid,
    ],
}
# Asset-existence assertions keyed by the component the pipeline starts at.
# kfp_validation_checks runs every assertion listed for start_at, each invoked as
# check(gbml_config_pb=...). Components without an entry (e.g. ConfigPopulator,
# DataPreprocessor, PostProcessor) get no asset checks.
START_COMPONENT_TO_ASSET_CHECKS_MAP = {
    GiGLComponents.SubgraphSampler.value: [
        assert_preprocessed_metadata_exists,
    ],
    GiGLComponents.SplitGenerator.value: [
        assert_preprocessed_metadata_exists,
        assert_subgraph_sampler_output_exists,
    ],
    GiGLComponents.Trainer.value: [
        assert_preprocessed_metadata_exists,
        assert_subgraph_sampler_output_exists,
        assert_split_generator_output_exists,
    ],
    GiGLComponents.Inferencer.value: [
        assert_preprocessed_metadata_exists,
        assert_subgraph_sampler_output_exists,
        assert_trained_model_exists,
    ],
}
# Resource-config checks for runs bounded by an exact (start_at, stop_after) pair.
# kfp_validation_checks consults this map first; when the pair is registered here
# these checks are run INSTEAD OF the START_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP
# entry for start_at. Every check is invoked as check(resource_config_pb=...).
START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP = {
    (GiGLComponents.SubgraphSampler.value, GiGLComponents.SubgraphSampler.value): [
        check_if_shared_resource_config_valid,
        check_if_subgraph_sampler_resource_config_valid,
    ],
}
# Resource-config checks keyed by the component the pipeline starts at.
# kfp_validation_checks falls back to this map when the (start_at, stop_after)
# pair is not registered in START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP; a
# start_at value with no entry here gets no checks. Every check is invoked as
# check(resource_config_pb=...). The shared resource config is validated for
# every listed start component.
START_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP = {
    GiGLComponents.ConfigPopulator.value: [
        check_if_shared_resource_config_valid,
    ],
    GiGLComponents.DataPreprocessor.value: [
        check_if_shared_resource_config_valid,
        check_if_preprocessor_resource_config_valid,
        check_if_subgraph_sampler_resource_config_valid,
        check_if_split_generator_resource_config_valid,
        check_if_trainer_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.SubgraphSampler.value: [
        check_if_shared_resource_config_valid,
        check_if_subgraph_sampler_resource_config_valid,
        check_if_split_generator_resource_config_valid,
        check_if_trainer_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.SplitGenerator.value: [
        check_if_shared_resource_config_valid,
        check_if_split_generator_resource_config_valid,
        check_if_trainer_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.Trainer.value: [
        check_if_shared_resource_config_valid,
        check_if_trainer_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.Inferencer.value: [
        check_if_shared_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.PostProcessor.value: [
        check_if_shared_resource_config_valid,
    ],
}
# Module-level logger used to report the outcome of the validation run.
logger = Logger()


def _select_checks(
    start_stop_map: dict,
    start_map: dict,
    start_at: str,
    stop_after: Optional[str],
) -> list:
    """Return the (start_at, stop_after)-specific checks if registered, else the start_at ones."""
    if stop_after is not None and (start_at, stop_after) in start_stop_map:
        return start_stop_map[(start_at, stop_after)]
    # No pair-specific entry: fall back to the per-start list (empty if unknown).
    return start_map.get(start_at, [])


def kfp_validation_checks(
    job_name: str,
    task_config_uri: Uri,
    start_at: str,
    resource_config_uri: Uri,
    stop_after: Optional[str] = None,
) -> None:
    """Validate the inputs of a GiGL pipeline run before it is launched.

    Runs, in order: the job-name check, the start/stop flag check, the
    template-config checks, the asset-existence checks, the resource-config
    checks and, when ``shared_config.should_skip_training`` is set, the
    trained-model existence check. Logs a success message once all have run.

    Args:
        job_name: Pipeline job name to validate.
        task_config_uri: URI of the task config YAML; parsed as a ``GbmlConfig``.
        start_at: Component the pipeline starts at; selects which checks run.
        resource_config_uri: URI of the resource config to load and validate.
        stop_after: Optional component the pipeline stops after. When the
            ``(start_at, stop_after)`` pair has a dedicated entry in one of the
            ``START_STOP_*`` maps, that entry replaces the per-``start_at`` one.

    Raises:
        Whatever the individual checks raise on failure; nothing is caught here.
    """
    # check if job_name is valid
    check_if_kfp_pipeline_job_name_valid(job_name=job_name)
    # check if start_at and stop_after aligns with glt backend use
    check_pipeline_has_valid_start_and_stop_flags(
        start_at=start_at, stop_after=stop_after, task_config_uri=task_config_uri.uri
    )

    proto_utils = ProtoUtils()
    gbml_config_pb: gbml_config_pb2.GbmlConfig = proto_utils.read_proto_from_yaml(
        uri=task_config_uri, proto_cls=gbml_config_pb2.GbmlConfig
    )
    resource_config_wrapper: GiglResourceConfigWrapper = get_resource_config(
        resource_config_uri=resource_config_uri
    )
    resource_config_pb: GiglResourceConfig = resource_config_wrapper.resource_config

    # check user defined classes and their runtime args
    cls_checks = _select_checks(
        START_STOP_COMPONENT_TO_CLS_CHECKS_MAP,
        START_COMPONENT_TO_CLS_CHECKS_MAP,
        start_at,
        stop_after,
    )
    for cls_check in cls_checks:
        cls_check(gbml_config_pb=gbml_config_pb)

    # check the existence of needed assets
    for asset_check in START_COMPONENT_TO_ASSET_CHECKS_MAP.get(start_at, []):
        asset_check(gbml_config_pb=gbml_config_pb)

    # check if user-provided resource config is valid
    resource_config_checks = _select_checks(
        START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP,
        START_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP,
        start_at,
        stop_after,
    )
    for resource_config_check in resource_config_checks:
        resource_config_check(resource_config_pb=resource_config_pb)

    # check if trained model file exist when skipping training
    if gbml_config_pb.shared_config.should_skip_training:
        assert_trained_model_exists(gbml_config_pb=gbml_config_pb)

    logger.info("[✅ SUCCESS] All checks passed successfully.")
if __name__ == "__main__":
    # CLI entry point: parse the pipeline arguments, convert the two URI strings
    # into Uri objects, and run the full set of validation checks.
    parser = argparse.ArgumentParser(
        description="Checks if config files and assets are valid for a GiGL pipeline run."
    )
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="GCS URI to template_or_frozen_config_uri",
    )
    parser.add_argument(
        "--start_at",
        type=str,
        help="Specify the component where to start the pipeline",
    )
    # Left unset (None) when omitted; kfp_validation_checks then uses only the
    # per-start_at check maps.
    parser.add_argument(
        "--stop_after",
        type=str,
        help="Specify the component where to stop the pipeline",
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="Runtime argument for resource and env specifications of each component",
    )
    args = parser.parse_args()

    kfp_validation_checks(
        job_name=args.job_name,
        task_config_uri=UriFactory.create_uri(args.task_config_uri),
        start_at=args.start_at,
        resource_config_uri=UriFactory.create_uri(args.resource_config_uri),
        stop_after=args.stop_after,
    )