"""gigl.src.post_process.post_processor: user-defined post-pipeline logic and temp-asset cleanup."""

import argparse
import sys
import tempfile
import traceback
from typing import Optional

from gigl.common import GcsUri, LocalUri, Uri, UriFactory
from gigl.common.logger import Logger
from gigl.common.metrics.decorators import flushes_metrics, profileit
from gigl.common.utils import os_utils
from gigl.common.utils.gcs import GcsUtils
from gigl.src.common.constants import gcs as gcs_constants
from gigl.src.common.constants.metrics import TIMER_POST_PROCESSOR_S
from gigl.src.common.translators.model_eval_metrics_translator import (
    EvalMetricsCollectionTranslator,
)
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.model_eval_metrics import EvalMetricsCollection
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.file_loader import FileLoader
from gigl.src.common.utils.metrics_service_provider import (
    get_metrics_service_instance,
    initialize_metrics,
)
from gigl.src.post_process.lib.base_post_processor import BasePostProcessor
from gigl.src.post_process.utils.unenumeration import unenumerate_all_inferred_bq_assets
from snapchat.research.gbml import gbml_config_pb2

# Module-level logger shared by PostProcessor and the CLI entry point below.
# (Removed stray Sphinx "[docs]" anchor residue that made this line invalid Python.)
logger = Logger()
class PostProcessor:
    """Runs user-defined logic after the whole pipeline finishes.

    Responsibilities (in order):
      * Optionally un-enumerates inferred BQ assets (GLT backend only).
      * Instantiates the user-provided ``BasePostProcessor`` subclass named in
        the config and runs it.
      * Writes any eval metrics the post processor returns to the configured URI.
      * Cleans up the applied task's temporary GCS assets unless the config
        opts out via ``should_skip_automatic_temp_asset_cleanup``.
    """

    def __run_post_process(
        self,
        gbml_config_pb: gbml_config_pb2.GbmlConfig,
        applied_task_identifier: AppliedTaskIdentifier,
    ) -> None:
        """Instantiate and run the user post processor, then clean up temp assets.

        Args:
            gbml_config_pb: Frozen task config; ``post_processor_config`` names
                the user class and its kwargs.
            applied_task_identifier: Unique job identifier; forwarded to the
                user class and used to locate the temp GCS directory.

        Raises:
            Exception: Re-raised if the user class cannot be imported or
                instantiated.
        """
        post_processor_cls_str: str = (
            gbml_config_pb.post_processor_config.post_processor_cls_path
        )
        # Copy the proto map into a plain dict so we do not mutate the shared
        # config proto in place (the original wrote into the map directly).
        kwargs = dict(gbml_config_pb.post_processor_config.post_processor_args)
        kwargs["applied_task_identifier"] = applied_task_identifier
        if post_processor_cls_str == "":
            logger.warning(
                "No post processor class path provided in config, will skip post processor"
            )
        else:
            try:
                post_processor_cls = os_utils.import_obj(post_processor_cls_str)
                post_processor: BasePostProcessor = post_processor_cls(**kwargs)
                assert isinstance(post_processor, BasePostProcessor)
                logger.info(
                    f"Instantiate class {post_processor_cls_str} with kwargs: {kwargs}"
                )
            except Exception as e:
                logger.error(
                    f"Could not instantiate class {post_processor_cls_str}: {e}"
                )
                # Bare `raise` preserves the original traceback.
                raise
            logger.info(
                f"Running user post processor class: {post_processor.__class__}, with config: {gbml_config_pb}"
            )
            post_processor_metrics: Optional[
                EvalMetricsCollection
            ] = post_processor.run_post_process(gbml_config_pb=gbml_config_pb)
            if post_processor_metrics is not None:
                self.__write_post_processor_metrics_to_uri(
                    model_eval_metrics=post_processor_metrics,
                    gbml_config_pb=gbml_config_pb,
                )

        # Run shared logic of cleaning up of assets considered temporary.
        # This runs even when no post processor class was configured.
        if gbml_config_pb.shared_config.should_skip_automatic_temp_asset_cleanup:
            logger.info(
                "Will skip automatic cleanup of temporary assets as `should_skip_automatic_temp_asset_cleanup`"
                + f" was set to truthy value: {gbml_config_pb.shared_config.should_skip_automatic_temp_asset_cleanup}"
            )
        else:
            gcs_utils = GcsUtils()
            temp_dir_gcs_path: GcsUri = gcs_constants.get_applied_task_temp_gcs_path(
                applied_task_identifier=applied_task_identifier
            )
            # Fixed stray "$" that the original f-string rendered literally.
            logger.info(
                f"Will automatically cleanup the temporary assets directory: {temp_dir_gcs_path}"
            )
            gcs_utils.delete_files_in_bucket_dir(gcs_path=temp_dir_gcs_path)

    def __write_post_processor_metrics_to_uri(
        self,
        model_eval_metrics: EvalMetricsCollection,
        gbml_config_pb: gbml_config_pb2.GbmlConfig,
    ) -> None:
        """Serialize eval metrics to a local temp file and upload it to the
        configured ``post_processor_log_metrics_uri``.
        """
        import os  # local import: only needed here, for temp-file cleanup

        file_loader = FileLoader()
        # delete=False keeps the path valid after we close the handle; we
        # remove the file ourselves in the finally block (the original leaked it).
        tfh = tempfile.NamedTemporaryFile(delete=False)
        tfh.close()
        local_tfh_uri = LocalUri(tfh.name)
        try:
            post_processor_log_metrics_uri = UriFactory.create_uri(
                uri=gbml_config_pb.shared_config.postprocessed_metadata.post_processor_log_metrics_uri
            )
            EvalMetricsCollectionTranslator.write_kfp_metrics_to_pipeline_metric_path(
                eval_metrics=model_eval_metrics, path=local_tfh_uri
            )
            file_loader.load_file(
                file_uri_src=local_tfh_uri, file_uri_dst=post_processor_log_metrics_uri
            )
            logger.info(f"Wrote eval metrics to {post_processor_log_metrics_uri.uri}.")
        finally:
            if os.path.exists(tfh.name):
                os.remove(tfh.name)

    def __should_run_unenumeration(
        self, gbml_config_wrapper: GbmlConfigPbWrapper
    ) -> bool:
        """
        When using the experimental GLT backend, we should run unenumeration
        in the post processor.
        """
        return gbml_config_wrapper.should_use_glt_backend

    def __run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
    ) -> None:
        """Load the task config, optionally un-enumerate inferred BQ assets,
        then run the user post processor and cleanup."""
        gbml_config_wrapper: GbmlConfigPbWrapper = (
            GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
                gbml_config_uri=task_config_uri
            )
        )
        if self.__should_run_unenumeration(gbml_config_wrapper=gbml_config_wrapper):
            logger.info(f"Running unenumeration for inferred assets in post processor")
            unenumerate_all_inferred_bq_assets(
                gbml_config_pb_wrapper=gbml_config_wrapper
            )
            logger.info(
                f"Finished running unenumeration for inferred assets in post processor"
            )
        self.__run_post_process(
            gbml_config_pb=gbml_config_wrapper.gbml_config_pb,
            applied_task_identifier=applied_task_identifier,
        )

    @flushes_metrics(get_metrics_service_instance_fn=get_metrics_service_instance)
    @profileit(
        metric_name=TIMER_POST_PROCESSOR_S,
        get_metrics_service_instance_fn=get_metrics_service_instance,
    )
    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
    ):
        """Public entry point; wraps ``__run`` with metrics flushing/profiling.

        Args:
            applied_task_identifier: Unique job identifier.
            task_config_uri: URI of the frozen GbmlConfig.
            resource_config_uri: Accepted for interface parity with other
                components; not consumed here.

        Exits the process (``sys.exit``) on any exception so the failure is
        surfaced to the orchestrator.
        """
        try:
            return self.__run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
            )
        except Exception as e:
            logger.error(
                "Post Processor failed due to a raised exception; which will follow"
            )
            logger.error(e)
            logger.error(traceback.format_exc())
            sys.exit(f"System will now exit: {e}")
if __name__ == "__main__":
    # CLI entry point: parse the three required job arguments, initialize
    # metrics, and run the post processor.
    parser = argparse.ArgumentParser(
        description="Program to run user defined logic that runs after the whole pipeline. "
        + "Subsequently cleans up any temporary assets"
    )
    # All three flags are consumed unconditionally below, so mark them
    # required for a clean argparse error instead of a later crash on None.
    parser.add_argument(
        "--job_name",
        type=str,
        required=True,
        help="Unique identifier for the job name",
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        required=True,
        help="Gbml config uri",
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        required=True,
        help="Runtime argument for resource and env specifications of each component",
    )
    args = parser.parse_args()

    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
    applied_task_identifier = AppliedTaskIdentifier(args.job_name)

    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)

    post_processor = PostProcessor()
    post_processor.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        resource_config_uri=resource_config_uri,
    )