import argparse
import sys
import tempfile
import traceback
from typing import Optional
from gigl.common import GcsUri, LocalUri, Uri, UriFactory
from gigl.common.logger import Logger
from gigl.common.metrics.decorators import flushes_metrics, profileit
from gigl.common.utils import os_utils
from gigl.common.utils.gcs import GcsUtils
from gigl.src.common.constants import gcs as gcs_constants
from gigl.src.common.constants.metrics import TIMER_POST_PROCESSOR_S
from gigl.src.common.translators.model_eval_metrics_translator import (
    EvalMetricsCollectionTranslator,
)
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.model_eval_metrics import EvalMetricsCollection
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.file_loader import FileLoader
from gigl.src.common.utils.metrics_service_provider import (
    get_metrics_service_instance,
    initialize_metrics,
)
from gigl.src.post_process.lib.base_post_processor import BasePostProcessor
from gigl.src.post_process.utils.unenumeration import unenumerate_all_inferred_bq_assets
from snapchat.research.gbml import gbml_config_pb2

logger = Logger()


class PostProcessor:
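    """Runs optional user-defined post-processing logic after the pipeline completes, then cleans up temporary assets."""
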
    def __run_post_process(
        self,
        gbml_config_pb: gbml_config_pb2.GbmlConfig,
        applied_task_identifier: AppliedTaskIdentifier,
    ):
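        """Instantiates and runs the user-provided post processor class, if one is configured, and then
        deletes the job's temporary GCS assets unless cleanup is explicitly skipped in the config."""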
        post_processor_cls_str: str = (
            gbml_config_pb.post_processor_config.post_processor_cls_path
        )
        # Copy the proto map into a plain dict so it can be mutated and splatted as kwargs
        # without modifying the underlying config proto.
        kwargs = dict(gbml_config_pb.post_processor_config.post_processor_args)
        kwargs["applied_task_identifier"] = applied_task_identifier
        if not post_processor_cls_str:
            logger.warning(
                "No post processor class path provided in config; skipping post processor"
            )
        else:
            try:
                post_processor_cls = os_utils.import_obj(post_processor_cls_str)
                post_processor: BasePostProcessor = post_processor_cls(**kwargs)
                assert isinstance(post_processor, BasePostProcessor)
                logger.info(
                    f"Instantiated class {post_processor_cls_str} with kwargs: {kwargs}"
                )
            except Exception as e:
                logger.error(
                    f"Could not instantiate class {post_processor_cls_str}: {e}"
                )
                raise e
            logger.info(
                f"Running user post processor class: {post_processor.__class__}, with config: {gbml_config_pb}"
            )
            post_processor_metrics: Optional[
                EvalMetricsCollection
            ] = post_processor.run_post_process(gbml_config_pb=gbml_config_pb)
            if post_processor_metrics is not None:
                self.__write_post_processor_metrics_to_uri(
                    model_eval_metrics=post_processor_metrics,
                    gbml_config_pb=gbml_config_pb,
                )
        # Shared logic: clean up temporary assets unless cleanup is explicitly skipped.
        if gbml_config_pb.shared_config.should_skip_automatic_temp_asset_cleanup:
            logger.info(
                "Will skip automatic cleanup of temporary assets as `should_skip_automatic_temp_asset_cleanup`"
                + f" was set to a truthy value: {gbml_config_pb.shared_config.should_skip_automatic_temp_asset_cleanup}"
            )
        else:
            gcs_utils = GcsUtils()
            temp_dir_gcs_path: GcsUri = gcs_constants.get_applied_task_temp_gcs_path(
                applied_task_identifier=applied_task_identifier
            )
            logger.info(
                f"Will automatically clean up the temporary assets directory: {temp_dir_gcs_path}"
            )
            gcs_utils.delete_files_in_bucket_dir(gcs_path=temp_dir_gcs_path)

    def __write_post_processor_metrics_to_uri(
        self,
        model_eval_metrics: EvalMetricsCollection,
        gbml_config_pb: gbml_config_pb2.GbmlConfig,
    ):
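        """Writes eval metrics to a local temp file in the KFP pipeline-metrics format, then uploads it
        to the post processor metrics URI configured in the post-processed metadata."""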
        file_loader = FileLoader()
        # delete=False keeps the temp file on disk so FileLoader can upload it below.
        tfh = tempfile.NamedTemporaryFile(delete=False)
        local_tfh_uri = LocalUri(tfh.name)
        post_processor_log_metrics_uri = UriFactory.create_uri(
            uri=gbml_config_pb.shared_config.postprocessed_metadata.post_processor_log_metrics_uri
        )
        EvalMetricsCollectionTranslator.write_kfp_metrics_to_pipeline_metric_path(
            eval_metrics=model_eval_metrics, path=local_tfh_uri
        )
        file_loader.load_file(
            file_uri_src=local_tfh_uri, file_uri_dst=post_processor_log_metrics_uri
        )
        logger.info(f"Wrote eval metrics to {post_processor_log_metrics_uri.uri}.")

    def __should_run_unenumeration(
        self, gbml_config_wrapper: GbmlConfigPbWrapper
    ) -> bool:
        """
        When using the experimental GLT backend, we should run unenumeration in the post processor.
        """
        return gbml_config_wrapper.should_use_glt_backend

    def __run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
    ):
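        """Loads the GbmlConfig from the given task config URI, runs unenumeration of inferred BQ assets
        when required, and then runs the post processor."""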
        gbml_config_wrapper: GbmlConfigPbWrapper = (
            GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
                gbml_config_uri=task_config_uri
            )
        )
        if self.__should_run_unenumeration(gbml_config_wrapper=gbml_config_wrapper):
            logger.info(f"Running unenumeration for inferred assets in post processor")
            unenumerate_all_inferred_bq_assets(
                gbml_config_pb_wrapper=gbml_config_wrapper
            )
            logger.info(
                "Finished running unenumeration for inferred assets in post processor"
            )
        self.__run_post_process(
            gbml_config_pb=gbml_config_wrapper.gbml_config_pb,
            applied_task_identifier=applied_task_identifier,
        )

    @flushes_metrics(get_metrics_service_instance_fn=get_metrics_service_instance)
    @profileit(
        metric_name=TIMER_POST_PROCESSOR_S,
        get_metrics_service_instance_fn=get_metrics_service_instance,
    )
    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
    ):
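        """Entry point for the Post Processor component; logs and exits the process if post processing fails."""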
        try:
            return self.__run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
            )
        except Exception as e:
            logger.error(
                "Post Processor failed due to a raised exception, which follows"
            )
            logger.error(e)
            logger.error(traceback.format_exc())
            sys.exit(f"System will now exit: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Program to run user-defined logic after the whole pipeline completes. "
        + "Subsequently cleans up any temporary assets."
    )
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="Gbml config uri",
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="Runtime argument for resource and env specifications of each component",
    )
    args = parser.parse_args()
    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
    applied_task_identifier = AppliedTaskIdentifier(args.job_name)
    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)
    post_processor = PostProcessor()
    post_processor.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        resource_config_uri=resource_config_uri,
    )