Source code for gigl.src.validation_check.config_validator

import argparse
from typing import Optional

from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
    GiglResourceConfigWrapper,
)
from gigl.src.validation_check.libs.frozen_config_path_checks import (
    assert_preprocessed_metadata_exists,
    assert_split_generator_output_exists,
    assert_subgraph_sampler_output_exists,
    assert_trained_model_exists,
)
from gigl.src.validation_check.libs.name_checks import (
    check_if_kfp_pipeline_job_name_valid,
)
from gigl.src.validation_check.libs.resource_config_checks import (
    check_if_inferencer_resource_config_valid,
    check_if_preprocessor_resource_config_valid,
    check_if_shared_resource_config_valid,
    check_if_split_generator_resource_config_valid,
    check_if_subgraph_sampler_resource_config_valid,
    check_if_trainer_resource_config_valid,
)
from gigl.src.validation_check.libs.template_config_checks import (
    check_if_data_preprocessor_config_cls_valid,
    check_if_graph_metadata_valid,
    check_if_inferencer_cls_valid,
    check_if_post_processor_cls_valid,
    check_if_preprocessed_metadata_valid,
    check_if_split_generator_config_valid,
    check_if_subgraph_sampler_config_valid,
    check_if_task_metadata_valid,
    check_if_trainer_cls_valid,
    check_pipeline_has_valid_start_and_stop_flags,
)
from snapchat.research.gbml import gbml_config_pb2
from snapchat.research.gbml.gigl_resource_config_pb2 import GiglResourceConfig

# Class/config checks keyed by an exact (start_at, stop_after) pair.
# kfp_validation_checks consults this map first; when the pair is not listed it
# falls back to START_COMPONENT_TO_CLS_CHECKS_MAP, keyed by start_at alone.
START_STOP_COMPONENT_TO_CLS_CHECKS_MAP = {
    # TODO: (svij-sc) Add checks as needed, otherwise we default to below anyways
    (
        GiGLComponents.SubgraphSampler.value,
        GiGLComponents.SubgraphSampler.value,
    ): [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_subgraph_sampler_config_valid,
    ],
}
# Class/config checks keyed by the component the pipeline starts at.
# kfp_validation_checks uses this map when no (start_at, stop_after) specific
# entry exists in START_STOP_COMPONENT_TO_CLS_CHECKS_MAP. Each callable is
# invoked as check(gbml_config_pb=...). A start_at value that is missing from
# this map results in no class checks being run.
START_COMPONENT_TO_CLS_CHECKS_MAP = {
    GiGLComponents.ConfigPopulator.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_data_preprocessor_config_cls_valid,
        check_if_subgraph_sampler_config_valid,
        check_if_split_generator_config_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    # Same set of checks as ConfigPopulator.
    GiGLComponents.DataPreprocessor.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_data_preprocessor_config_cls_valid,
        check_if_subgraph_sampler_config_valid,
        check_if_split_generator_config_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    # From here on the data preprocessor class check is replaced by a check of
    # the preprocessed metadata.
    GiGLComponents.SubgraphSampler.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_preprocessed_metadata_valid,
        check_if_subgraph_sampler_config_valid,
        check_if_split_generator_config_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.SplitGenerator.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_preprocessed_metadata_valid,
        check_if_split_generator_config_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.Trainer.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_preprocessed_metadata_valid,
        check_if_trainer_cls_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.Inferencer.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_preprocessed_metadata_valid,
        check_if_inferencer_cls_valid,
        check_if_post_processor_cls_valid,
    ],
    GiGLComponents.PostProcessor.value: [
        check_if_graph_metadata_valid,
        check_if_task_metadata_valid,
        check_if_post_processor_cls_valid,
    ],
}
# Asset-existence assertions keyed by the component the pipeline starts at.
# kfp_validation_checks calls each one as check(gbml_config_pb=...). Components
# without an entry (e.g. ConfigPopulator, DataPreprocessor, PostProcessor) get
# no asset checks.
# NOTE(review): presumably these verify that outputs of the skipped upstream
# components already exist -- confirm against frozen_config_path_checks.
START_COMPONENT_TO_ASSET_CHECKS_MAP = {
    GiGLComponents.SubgraphSampler.value: [
        assert_preprocessed_metadata_exists,
    ],
    GiGLComponents.SplitGenerator.value: [
        assert_preprocessed_metadata_exists,
        assert_subgraph_sampler_output_exists,
    ],
    GiGLComponents.Trainer.value: [
        assert_preprocessed_metadata_exists,
        assert_subgraph_sampler_output_exists,
        assert_split_generator_output_exists,
    ],
    GiGLComponents.Inferencer.value: [
        assert_preprocessed_metadata_exists,
        assert_subgraph_sampler_output_exists,
        assert_trained_model_exists,
    ],
}
# Resource-config checks keyed by an exact (start_at, stop_after) pair.
# kfp_validation_checks consults this map first; when the pair is not listed it
# falls back to START_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP, keyed by
# start_at alone.
START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP = {
    (
        GiGLComponents.SubgraphSampler.value,
        GiGLComponents.SubgraphSampler.value,
    ): [
        check_if_shared_resource_config_valid,
        check_if_subgraph_sampler_resource_config_valid,
    ],
}
# Resource-config checks keyed by the component the pipeline starts at.
# kfp_validation_checks uses this map when no (start_at, stop_after) specific
# entry exists in START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP. Each
# callable is invoked as check(resource_config_pb=...). The shared resource
# config check is present for every listed start component.
START_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP = {
    GiGLComponents.ConfigPopulator.value: [
        check_if_shared_resource_config_valid,
    ],
    GiGLComponents.DataPreprocessor.value: [
        check_if_shared_resource_config_valid,
        check_if_preprocessor_resource_config_valid,
        check_if_subgraph_sampler_resource_config_valid,
        check_if_split_generator_resource_config_valid,
        check_if_trainer_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.SubgraphSampler.value: [
        check_if_shared_resource_config_valid,
        check_if_subgraph_sampler_resource_config_valid,
        check_if_split_generator_resource_config_valid,
        check_if_trainer_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.SplitGenerator.value: [
        check_if_shared_resource_config_valid,
        check_if_split_generator_resource_config_valid,
        check_if_trainer_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.Trainer.value: [
        check_if_shared_resource_config_valid,
        check_if_trainer_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.Inferencer.value: [
        check_if_shared_resource_config_valid,
        check_if_inferencer_resource_config_valid,
    ],
    GiGLComponents.PostProcessor.value: [
        check_if_shared_resource_config_valid,
    ],
}
# Module-level logger; kfp_validation_checks uses it to report success.
logger: Logger = Logger()
def _select_checks(
    start_stop_checks_map: dict,
    start_checks_map: dict,
    start_at: str,
    stop_after: Optional[str],
) -> list:
    """Return the checks for (start_at, stop_after) if registered, else those for start_at alone."""
    if stop_after is not None and (start_at, stop_after) in start_stop_checks_map:
        return start_stop_checks_map[(start_at, stop_after)]
    # No pair-specific entry: fall back to the start-only map; unknown
    # start components yield no checks rather than an error.
    return start_checks_map.get(start_at, [])


def kfp_validation_checks(
    job_name: str,
    task_config_uri: Uri,
    start_at: str,
    resource_config_uri: Uri,
    stop_after: Optional[str] = None,
) -> None:
    """Run all pre-flight validation checks for a GiGL pipeline run.

    The checks run in this order: job name, start/stop flags, user-defined
    class/config checks, asset-existence checks, resource-config checks and,
    when training is skipped, a trained-model existence check. A success
    message is logged once every check has returned.

    Args:
        job_name: Pipeline job name, validated by
            ``check_if_kfp_pipeline_job_name_valid``.
        task_config_uri: URI of the task config YAML; parsed into a
            ``GbmlConfig`` proto.
        start_at: Component the pipeline starts at; selects which checks run.
        resource_config_uri: URI passed to ``get_resource_config`` to load the
            resource config.
        stop_after: Optional component the pipeline stops after. When the
            ``(start_at, stop_after)`` pair has a dedicated entry in the
            START_STOP_* maps, that entry is used instead of the START_* maps.

    Raises:
        Whatever the individual checks raise on failure (assumption: each check
        signals an invalid config by raising -- confirm in the check libs).
    """
    # check if job_name is valid
    check_if_kfp_pipeline_job_name_valid(job_name=job_name)
    # check if start_at and stop_after aligns with glt backend use
    check_pipeline_has_valid_start_and_stop_flags(
        start_at=start_at,
        stop_after=stop_after,
        task_config_uri=task_config_uri.uri,
    )

    proto_utils = ProtoUtils()
    gbml_config_pb: gbml_config_pb2.GbmlConfig = proto_utils.read_proto_from_yaml(
        uri=task_config_uri, proto_cls=gbml_config_pb2.GbmlConfig
    )
    resource_config_wrapper: GiglResourceConfigWrapper = get_resource_config(
        resource_config_uri=resource_config_uri
    )
    resource_config_pb: GiglResourceConfig = resource_config_wrapper.resource_config

    # check user defined classes and their runtime args
    for cls_check in _select_checks(
        START_STOP_COMPONENT_TO_CLS_CHECKS_MAP,
        START_COMPONENT_TO_CLS_CHECKS_MAP,
        start_at,
        stop_after,
    ):
        cls_check(gbml_config_pb=gbml_config_pb)

    # check the existence of needed assets
    for asset_check in START_COMPONENT_TO_ASSET_CHECKS_MAP.get(start_at, []):
        asset_check(gbml_config_pb=gbml_config_pb)

    # check if user-provided resource config is valid
    for resource_config_check in _select_checks(
        START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP,
        START_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP,
        start_at,
        stop_after,
    ):
        resource_config_check(resource_config_pb=resource_config_pb)

    # check if trained model file exist when skipping training
    if gbml_config_pb.shared_config.should_skip_training:
        assert_trained_model_exists(gbml_config_pb=gbml_config_pb)

    logger.info("[✅ SUCCESS] All checks passed successfully.")
if __name__ == "__main__":
    # Command-line entry point: parse the pipeline arguments and run every
    # validation check against them.
    parser = argparse.ArgumentParser(
        description="Checks if config files and assets are valid for a GiGL pipeline run."
    )
    # All options are plain string flags; they are registered in this fixed
    # order so the generated --help output is stable.
    for _flag, _help_text in (
        ("--job_name", "Unique identifier for the job name"),
        ("--task_config_uri", "GCS URI to template_or_frozen_config_uri"),
        ("--start_at", "Specify the component where to start the pipeline"),
        ("--stop_after", "Specify the component where to stop the pipeline"),
        (
            "--resource_config_uri",
            "Runtime argument for resource and env specifications of each component",
        ),
    ):
        parser.add_argument(_flag, type=str, help=_help_text)

    args = parser.parse_args()

    # URI strings are wrapped into Uri objects before being handed over.
    kfp_validation_checks(
        job_name=args.job_name,
        task_config_uri=UriFactory.create_uri(args.task_config_uri),
        start_at=args.start_at,
        resource_config_uri=UriFactory.create_uri(args.resource_config_uri),
        stop_after=args.stop_after,
    )