from typing import Union
from google.cloud.aiplatform_v1.types.accelerator_type import AcceleratorType
from gigl.common.logger import Logger
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
    GiglResourceConfigWrapper,
)
from gigl.src.validation_check.libs.utils import (
    assert_proto_field_value_is_truthy,
    assert_proto_has_field,
)
from snapchat.research.gbml import gigl_resource_config_pb2
def _check_if_dataflow_resource_config_valid(
    dataflow_resource_config_pb: gigl_resource_config_pb2.DataflowResourceConfig,
) -> None:
    """
    Checks if the provided Dataflow resource configuration is valid.
    Args:
        dataflow_resource_config_pb (gigl_resource_config_pb2.DataflowResourceConfig): The dataflow resource configuration to be checked.
    Returns:
        None
    """
    for field in ["num_workers", "max_num_workers", "disk_size_gb", "machine_type"]:
        assert_proto_field_value_is_truthy(
            proto=dataflow_resource_config_pb, field_name=field
        )
def _check_if_spark_resource_config_valid(
    spark_resource_config_pb: gigl_resource_config_pb2.SparkResourceConfig,
) -> None:
    """
    Checks if the provided Spark resource configuration is valid.
    Args:
        spark_resource_config_pb (gigl_resource_config_pb2.SparkResourceConfig): The Spark resource configuration protobuf object.
    Returns:
        None
    """
    for field in ["machine_type", "num_local_ssds", "num_replicas"]:
        assert_proto_field_value_is_truthy(
            proto=spark_resource_config_pb, field_name=field
        )
[docs]
def check_if_shared_resource_config_valid(
    resource_config_pb: gigl_resource_config_pb2.GiglResourceConfig,
) -> None:
    """
    Check if SharedResourceConfig specification is valid:
     - SharedResourceConfig or a SharedResourceConfig uri must be accessible in the resource config.
     - CommonComputeConfig must have appropriate fields defined.
    Args:
        resource_config_pb (gigl_resource_config_pb2.GiglResourceConfig): The resource config to be checked.
    Returns:
        None
    """
    logger.info("Config validation check: if resource config shared_resource is valid.")
    wrapper = GiglResourceConfigWrapper(resource_config=resource_config_pb)
    assert (
        wrapper.shared_resource_config
    ), "Invalid 'shared_resource_config'; must provide shared_resource_config."
    assert_proto_has_field(
        proto=wrapper.shared_resource_config, field_name="common_compute_config"
    )
    common_compute_config_pb = wrapper.shared_resource_config.common_compute_config
    for field in [
        "project",
        "region",
        "temp_assets_bucket",
        "temp_regional_assets_bucket",
        "perm_assets_bucket",
        "temp_assets_bq_dataset_name",
        "embedding_bq_dataset_name",
        "gcp_service_account_email",
        "dataflow_runner",
    ]:
        assert_proto_field_value_is_truthy(
            proto=common_compute_config_pb, field_name=field
        ) 
[docs]
def check_if_preprocessor_resource_config_valid(
    resource_config_pb: gigl_resource_config_pb2.GiglResourceConfig,
) -> None:
    logger.info(
        "Config validation check: if resource config preprocessor_config is valid."
    )
    preprocessor_config: gigl_resource_config_pb2.DataPreprocessorConfig = (
        resource_config_pb.preprocessor_config
    )
    _check_if_dataflow_resource_config_valid(
        dataflow_resource_config_pb=preprocessor_config.node_preprocessor_config
    )
    _check_if_dataflow_resource_config_valid(
        dataflow_resource_config_pb=preprocessor_config.edge_preprocessor_config
    ) 
[docs]
def check_if_subgraph_sampler_resource_config_valid(
    resource_config_pb: gigl_resource_config_pb2.GiglResourceConfig,
) -> None:
    logger.info(
        "Config validation check: if resource config subgraph_sampler_config is valid."
    )
    _check_if_spark_resource_config_valid(
        spark_resource_config_pb=resource_config_pb.subgraph_sampler_config
    ) 
[docs]
def check_if_split_generator_resource_config_valid(
    resource_config_pb: gigl_resource_config_pb2.GiglResourceConfig,
) -> None:
    logger.info(
        "Config validation check: if resource config split_generator_config is valid."
    )
    _check_if_spark_resource_config_valid(
        spark_resource_config_pb=resource_config_pb.split_generator_config
    ) 
[docs]
def check_if_trainer_resource_config_valid(
    resource_config_pb: gigl_resource_config_pb2.GiglResourceConfig,
) -> None:
    logger.info("Config validation check: if resource config trainer_config is valid.")
    wrapper = GiglResourceConfigWrapper(resource_config=resource_config_pb)
    assert (
        wrapper.trainer_config
    ), "Invalid 'trainer_config'; must provide trainer_config."
    trainer_config: Union[
        gigl_resource_config_pb2.LocalResourceConfig,
        gigl_resource_config_pb2.VertexAiResourceConfig,
        gigl_resource_config_pb2.KFPResourceConfig,
    ] = wrapper.trainer_config
    if isinstance(trainer_config, gigl_resource_config_pb2.LocalResourceConfig):
        assert_proto_field_value_is_truthy(
            proto=trainer_config, field_name="num_workers"
        )
    else:
        # Case where trainer config is gigl_resource_config_pb2.VertexAiResourceConfig or gigl_resource_config_pb2.KFPResourceConfig
        if isinstance(trainer_config, gigl_resource_config_pb2.VertexAiResourceConfig):
            assert_proto_field_value_is_truthy(
                proto=trainer_config, field_name="machine_type"
            )
        elif isinstance(trainer_config, gigl_resource_config_pb2.KFPResourceConfig):
            for field in [
                "cpu_request",
                "memory_request",
            ]:
                assert_proto_field_value_is_truthy(
                    proto=trainer_config, field_name=field
                )
        else:
            raise ValueError(
                f"""Expected distributed trainer config to be one of {gigl_resource_config_pb2.LocalResourceConfig.__name__}, 
                {gigl_resource_config_pb2.VertexAiResourceConfig.__name__}, 
                or {gigl_resource_config_pb2.KFPResourceConfig.__name__}. 
                Got {type(trainer_config)}"""
            )
        for field in [
            "gpu_type",
            "num_replicas",
        ]:
            assert_proto_field_value_is_truthy(proto=trainer_config, field_name=field)
        if trainer_config.gpu_type == AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name:  # type: ignore
            assert (
                trainer_config.gpu_limit == 0
            ), f"""gpu_limit must be equal to 0 for cpu training, indicated by provided gpu_type {trainer_config.gpu_type}. 
                Got gpu_limit {trainer_config.gpu_limit}"""
        else:
            assert (
                trainer_config.gpu_limit > 0
            ), f"""gpu_limit must be greater than 0 for gpu training, indicated by provided gpu_type {trainer_config.gpu_type}. 
                Got gpu_limit {trainer_config.gpu_limit}. Use gpu_type {AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name} for cpu training."""  # type: ignore 
[docs]
def check_if_inferencer_resource_config_valid(
    resource_config_pb: gigl_resource_config_pb2.GiglResourceConfig,
) -> None:
    logger.info(
        "Config validation check: if resource config inferencer_config is valid."
    )
    resource_config_wrapper = GiglResourceConfigWrapper(
        resource_config=resource_config_pb
    )
    inferencer_config = resource_config_wrapper.inferencer_config
    if isinstance(inferencer_config, gigl_resource_config_pb2.DataflowResourceConfig):
        _check_if_dataflow_resource_config_valid(
            dataflow_resource_config_pb=inferencer_config
        )
    elif isinstance(inferencer_config, gigl_resource_config_pb2.VertexAiResourceConfig):
        assert_proto_field_value_is_truthy(
            proto=inferencer_config, field_name="machine_type"
        )
        assert_proto_field_value_is_truthy(
            proto=inferencer_config, field_name="gpu_type"
        )
        assert_proto_field_value_is_truthy(
            proto=inferencer_config, field_name="num_replicas"
        )
        if inferencer_config.gpu_type == AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name:  # type: ignore
            assert (
                inferencer_config.gpu_limit == 0
            ), f"""gpu_limit must be equal to 0 for cpu training, indicated by provided gpu_type {inferencer_config.gpu_type}. 
                Got gpu_limit {inferencer_config.gpu_limit}"""
        else:
            assert (
                inferencer_config.gpu_limit > 0
            ), f"""gpu_limit must be greater than 0 for gpu training, indicated by provided gpu_type {inferencer_config.gpu_type}. 
                Got gpu_limit {inferencer_config.gpu_limit}. Use gpu_type {AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name} for cpu training."""  # type: ignore
    elif isinstance(inferencer_config, gigl_resource_config_pb2.LocalResourceConfig):
        assert_proto_field_value_is_truthy(
            proto=inferencer_config, field_name="num_workers"
        )
    else:
        raise ValueError(
            f"""Expected inferencer config to be one of {gigl_resource_config_pb2.DataflowResourceConfig.__name__}, 
            {gigl_resource_config_pb2.VertexAiResourceConfig.__name__}, 
            or {gigl_resource_config_pb2.LocalResourceConfig.__name__}. 
            Got {type(inferencer_config)}"""
        )