import argparse
from typing import Optional
from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
GiglResourceConfigWrapper,
)
from gigl.src.validation_check.libs.frozen_config_path_checks import (
assert_preprocessed_metadata_exists,
assert_split_generator_output_exists,
assert_subgraph_sampler_output_exists,
assert_trained_model_exists,
)
from gigl.src.validation_check.libs.gbml_and_resource_config_compatibility_checks import (
check_inferencer_graph_store_compatibility,
check_trainer_graph_store_compatibility,
)
from gigl.src.validation_check.libs.name_checks import (
check_if_kfp_pipeline_job_name_valid,
)
from gigl.src.validation_check.libs.resource_config_checks import (
check_if_inferencer_resource_config_valid,
check_if_preprocessor_resource_config_valid,
check_if_shared_resource_config_valid,
check_if_split_generator_resource_config_valid,
check_if_subgraph_sampler_resource_config_valid,
check_if_trainer_resource_config_valid,
)
from gigl.src.validation_check.libs.template_config_checks import (
check_if_data_preprocessor_config_cls_valid,
check_if_graph_metadata_valid,
check_if_inferencer_cls_valid,
check_if_post_processor_cls_valid,
check_if_preprocessed_metadata_valid,
check_if_split_generator_config_valid,
check_if_subgraph_sampler_config_valid,
check_if_task_metadata_valid,
check_if_trainer_cls_valid,
check_pipeline_has_valid_start_and_stop_flags,
)
from snapchat.research.gbml import gbml_config_pb2
from snapchat.research.gbml.gigl_resource_config_pb2 import GiglResourceConfig
[docs]
START_STOP_COMPONENT_TO_CLS_CHECKS_MAP = {
# TODO: (svij-sc) Add checks as needed, otherwise we default to below anyways
(GiGLComponents.SubgraphSampler.value, GiGLComponents.SubgraphSampler.value): [
check_if_graph_metadata_valid,
check_if_task_metadata_valid,
check_if_subgraph_sampler_config_valid,
],
}
[docs]
START_COMPONENT_TO_CLS_CHECKS_MAP = {
GiGLComponents.ConfigPopulator.value: [
check_if_graph_metadata_valid,
check_if_task_metadata_valid,
check_if_data_preprocessor_config_cls_valid,
check_if_subgraph_sampler_config_valid,
check_if_split_generator_config_valid,
check_if_trainer_cls_valid,
check_if_inferencer_cls_valid,
check_if_post_processor_cls_valid,
],
GiGLComponents.DataPreprocessor.value: [
check_if_graph_metadata_valid,
check_if_task_metadata_valid,
check_if_data_preprocessor_config_cls_valid,
check_if_subgraph_sampler_config_valid,
check_if_split_generator_config_valid,
check_if_trainer_cls_valid,
check_if_inferencer_cls_valid,
check_if_post_processor_cls_valid,
],
GiGLComponents.SubgraphSampler.value: [
check_if_graph_metadata_valid,
check_if_task_metadata_valid,
check_if_preprocessed_metadata_valid,
check_if_subgraph_sampler_config_valid,
check_if_split_generator_config_valid,
check_if_trainer_cls_valid,
check_if_inferencer_cls_valid,
check_if_post_processor_cls_valid,
],
GiGLComponents.SplitGenerator.value: [
check_if_graph_metadata_valid,
check_if_task_metadata_valid,
check_if_preprocessed_metadata_valid,
check_if_split_generator_config_valid,
check_if_trainer_cls_valid,
check_if_inferencer_cls_valid,
check_if_post_processor_cls_valid,
],
GiGLComponents.Trainer.value: [
check_if_graph_metadata_valid,
check_if_task_metadata_valid,
check_if_preprocessed_metadata_valid,
check_if_trainer_cls_valid,
check_if_inferencer_cls_valid,
check_if_post_processor_cls_valid,
],
GiGLComponents.Inferencer.value: [
check_if_graph_metadata_valid,
check_if_task_metadata_valid,
check_if_preprocessed_metadata_valid,
check_if_inferencer_cls_valid,
check_if_post_processor_cls_valid,
],
GiGLComponents.PostProcessor.value: [
check_if_graph_metadata_valid,
check_if_task_metadata_valid,
check_if_post_processor_cls_valid,
],
}
[docs]
START_COMPONENT_TO_ASSET_CHECKS_MAP = {
GiGLComponents.SubgraphSampler.value: [
assert_preprocessed_metadata_exists,
],
GiGLComponents.SplitGenerator.value: [
assert_preprocessed_metadata_exists,
assert_subgraph_sampler_output_exists,
],
GiGLComponents.Trainer.value: [
assert_preprocessed_metadata_exists,
assert_subgraph_sampler_output_exists,
assert_split_generator_output_exists,
],
GiGLComponents.Inferencer.value: [
assert_preprocessed_metadata_exists,
assert_subgraph_sampler_output_exists,
assert_trained_model_exists,
],
}
[docs]
START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP = {
(GiGLComponents.SubgraphSampler.value, GiGLComponents.SubgraphSampler.value): [
check_if_shared_resource_config_valid,
check_if_subgraph_sampler_resource_config_valid,
],
}
[docs]
START_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP = {
GiGLComponents.ConfigPopulator.value: [
check_if_shared_resource_config_valid,
check_if_preprocessor_resource_config_valid,
check_if_subgraph_sampler_resource_config_valid,
check_if_split_generator_resource_config_valid,
check_if_trainer_resource_config_valid,
check_if_inferencer_resource_config_valid,
],
GiGLComponents.DataPreprocessor.value: [
check_if_shared_resource_config_valid,
check_if_preprocessor_resource_config_valid,
check_if_subgraph_sampler_resource_config_valid,
check_if_split_generator_resource_config_valid,
check_if_trainer_resource_config_valid,
check_if_inferencer_resource_config_valid,
],
GiGLComponents.SubgraphSampler.value: [
check_if_shared_resource_config_valid,
check_if_subgraph_sampler_resource_config_valid,
check_if_split_generator_resource_config_valid,
check_if_trainer_resource_config_valid,
check_if_inferencer_resource_config_valid,
],
GiGLComponents.SplitGenerator.value: [
check_if_shared_resource_config_valid,
check_if_split_generator_resource_config_valid,
check_if_trainer_resource_config_valid,
check_if_inferencer_resource_config_valid,
],
GiGLComponents.Trainer.value: [
check_if_shared_resource_config_valid,
check_if_trainer_resource_config_valid,
check_if_inferencer_resource_config_valid,
],
GiGLComponents.Inferencer.value: [
check_if_shared_resource_config_valid,
check_if_inferencer_resource_config_valid,
],
GiGLComponents.PostProcessor.value: [
check_if_shared_resource_config_valid,
],
}
# Resource config checks to skip when using live subgraph sampling backend
[docs]
RESOURCE_CONFIG_CHECKS_TO_SKIP_WITH_LIVE_SGS_BACKEND = [
check_if_subgraph_sampler_resource_config_valid,
check_if_split_generator_resource_config_valid,
]
# Map of start components to graph store compatibility checks to run
# Only run trainer checks when starting at or before Trainer
# Only run inferencer checks when starting at or before Inferencer
[docs]
START_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS = {
GiGLComponents.ConfigPopulator.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.DataPreprocessor.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.SubgraphSampler.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.SplitGenerator.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.Trainer.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.Inferencer.value: [
check_inferencer_graph_store_compatibility,
],
# PostProcessor doesn't need graph store compatibility checks
}
# Map of (start, stop) component tuples to graph store compatibility checks
[docs]
STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP = {
GiGLComponents.Trainer.value: [
check_inferencer_graph_store_compatibility,
],
}
def _run_gbml_and_resource_config_compatibility_checks(
start_at: str,
stop_after: Optional[str],
gbml_config_pb_wrapper: GbmlConfigPbWrapper,
resource_config_wrapper: GiglResourceConfigWrapper,
) -> None:
"""
Run compatibility checks between GbmlConfig and GiglResourceConfig.
These checks verify that graph store mode configurations are consistent
across both the template config (GbmlConfig) and resource config (GiglResourceConfig).
Args:
start_at: The component to start at.
stop_after: Optional component to stop after.
gbml_config_pb_wrapper: The GbmlConfig wrapper (template config).
resource_config_wrapper: The GiglResourceConfig wrapper (resource config).
"""
# Get the appropriate compatibility checks based on start/stop components
compatibility_checks = set(
START_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS.get(start_at, [])
)
if stop_after in STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP:
for skipped_check in STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP[
stop_after
]:
compatibility_checks.discard(skipped_check)
for check in compatibility_checks:
check(
gbml_config_pb_wrapper=gbml_config_pb_wrapper,
resource_config_wrapper=resource_config_wrapper,
)
[docs]
def kfp_validation_checks(
job_name: str,
task_config_uri: Uri,
start_at: str,
resource_config_uri: Uri,
stop_after: Optional[str] = None,
) -> None:
# check if job_name is valid
check_if_kfp_pipeline_job_name_valid(job_name=job_name)
# check if start_at and stop_after aligns with live subgraph sampling backend use
check_pipeline_has_valid_start_and_stop_flags(
start_at=start_at, stop_after=stop_after, task_config_uri=task_config_uri.uri
)
gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
gbml_config_uri=task_config_uri
)
gbml_config_pb: gbml_config_pb2.GbmlConfig = gbml_config_pb_wrapper.gbml_config_pb
should_use_live_sgs_backend = gbml_config_pb_wrapper.should_use_glt_backend
resource_config_wrapper: GiglResourceConfigWrapper = get_resource_config(
resource_config_uri=resource_config_uri
)
resource_config_pb: GiglResourceConfig = resource_config_wrapper.resource_config
# check user defined classes and their runtime args
if (
stop_after is not None
and (start_at, stop_after) in START_STOP_COMPONENT_TO_CLS_CHECKS_MAP
):
cls_checks = START_STOP_COMPONENT_TO_CLS_CHECKS_MAP[(start_at, stop_after)]
else:
cls_checks = START_COMPONENT_TO_CLS_CHECKS_MAP.get(start_at, [])
for cls_check in cls_checks:
cls_check(gbml_config_pb=gbml_config_pb)
# check the existence of needed assets
for asset_check in START_COMPONENT_TO_ASSET_CHECKS_MAP.get(start_at, []):
asset_check(gbml_config_pb=gbml_config_pb)
# check if user-provided resource config is valid
# Skip SGS and split resource config checks when using live subgraph sampling backend
resource_config_checks_to_skip = (
RESOURCE_CONFIG_CHECKS_TO_SKIP_WITH_LIVE_SGS_BACKEND
if should_use_live_sgs_backend
else []
)
if (
stop_after is not None
and (start_at, stop_after) in START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP
):
resource_config_checks = START_STOP_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP[
(start_at, stop_after)
]
else:
resource_config_checks = START_COMPONENT_TO_RESOURCE_CONFIG_CHECKS_MAP.get(
start_at, []
)
for resource_config_check in resource_config_checks:
if resource_config_check not in resource_config_checks_to_skip:
resource_config_check(resource_config_pb=resource_config_pb)
else:
logger.info(
f"Skipping resource config check {resource_config_check.__name__} because we are using live subgraph sampling backend."
)
# check compatibility between template config and resource config for graph store mode
# These checks ensure that if graph store mode is enabled in one config, it's also enabled in the other
_run_gbml_and_resource_config_compatibility_checks(
start_at=start_at,
stop_after=stop_after,
gbml_config_pb_wrapper=gbml_config_pb_wrapper,
resource_config_wrapper=resource_config_wrapper,
)
# check if trained model file exist when skipping training
if gbml_config_pb.shared_config.should_skip_training == True:
assert_trained_model_exists(gbml_config_pb=gbml_config_pb)
logger.info("[✅ SUCCESS] All checks passed successfully.")
if __name__ == "__main__":
[docs]
parser = argparse.ArgumentParser(
description="Checks if config files and assets are valid for a GiGL pipeline run."
)
parser.add_argument(
"--job_name",
type=str,
help="Unique identifier for the job name",
)
parser.add_argument(
"--task_config_uri",
type=str,
help="GCS URI to template_or_frozen_config_uri",
)
parser.add_argument(
"--start_at",
type=str,
help="Specify the component where to start the pipeline",
)
parser.add_argument(
"--stop_after",
type=str,
help="Specify the component where to stop the pipeline",
)
parser.add_argument(
"--resource_config_uri",
type=str,
help="Runtime argument for resource and env specifications of each component",
)
args = parser.parse_args()
kfp_validation_checks(
job_name=args.job_name,
task_config_uri=UriFactory.create_uri(args.task_config_uri),
start_at=args.start_at,
resource_config_uri=UriFactory.create_uri(args.resource_config_uri),
stop_after=args.stop_after,
)