import re
from typing import Any, Dict, Optional
from gigl.common import UriFactory
from gigl.common.logger import Logger
from gigl.common.utils import os_utils
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.src.common.constants.components import GLT_BACKEND_UNSUPPORTED_COMPONENTS
from gigl.src.common.translators.gbml_protos_translator import GbmlProtosTranslator
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.subgraph_sampling_strategy import (
SubgraphSamplingStrategyPbWrapper,
)
from gigl.src.common.types.pb_wrappers.task_metadata import TaskMetadataPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.data_preprocessor.lib.data_preprocessor_config import (
DataPreprocessorConfig,
)
from gigl.src.inference.v1.lib.base_inferencer import BaseInferencer
from gigl.src.post_process.lib.base_post_processor import BasePostProcessor
from gigl.src.training.v1.lib.base_trainer import BaseTrainer
from gigl.src.validation_check.libs.utils import assert_proto_field_value_is_truthy
from snapchat.research.gbml import gbml_config_pb2, preprocessed_metadata_pb2

logger = Logger()

def check_if_kfp_pipeline_job_name_valid(job_name: str) -> None:
"""
Check if kfp pipeline job name valid. It is used to start spark cluster and must match pattern.
The kfp pipeline job name is also used to generate AppliedTaskIdentifier for each component.
"""
    logger.info("Config validation check: if job_name is valid.")
    if not re.match(r"^(?:[a-z](?:[-_a-z0-9]{0,49}[a-z0-9])?)$", job_name):
        raise ValueError(
            f"Invalid 'job_name'. Only lowercase letters, numbers, dashes, and underscores are allowed. "
            f"The value must start with a lowercase letter, end with a lowercase letter or number, "
            f"and be at most 51 characters long. "
            f"'job_name' provided: {job_name}."
        )
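
# Example (hypothetical job names; a minimal sketch of what the pattern accepts):
#   check_if_kfp_pipeline_job_name_valid("gigl-job-01")  # passes
#   check_if_kfp_pipeline_job_name_valid("My_Job")       # raises ValueError (uppercase)
#   check_if_kfp_pipeline_job_name_valid("1-job")        # raises ValueError (starts with a digit)
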
def check_pipeline_has_valid_start_and_stop_flags(
start_at: str,
stop_after: Optional[str],
task_config_uri: str,
) -> None:
"""
Check if start_at and stop_after are valid with current static (gigl) or dynamic (glt) backend
"""
gbml_config_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
gbml_config_uri=UriFactory.create_uri(task_config_uri)
)
components = [start_at] if stop_after is None else [start_at, stop_after]
for component in components:
if gbml_config_wrapper.should_use_glt_backend:
if component in GLT_BACKEND_UNSUPPORTED_COMPONENTS:
                raise ValueError(
                    f"Invalid component {component} for GLT Backend. "
                    f"GLT Backend does not support components {GLT_BACKEND_UNSUPPORTED_COMPONENTS}."
                )
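
# Example (hypothetical component names and config URI; a minimal sketch):
#   check_pipeline_has_valid_start_and_stop_flags(
#       start_at="config_populator",
#       stop_after="trainer",
#       task_config_uri="gs://my-bucket/frozen_gbml_config.yaml",
#   )
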
def check_if_runtime_args_all_str(args_name: str, runtime_args: Dict[str, Any]) -> None:
"""
Check if all values of the given runtime arguements are string.
"""
for arg_key, arg_value in runtime_args.items():
if type(arg_value) is not str:
            raise ValueError(
                f"Invalid type for runtime arguments under {args_name}; values must be strings. "
                f"Got {arg_value} with type {type(arg_value)} for {arg_key}."
            )
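
# Example (hypothetical args; a minimal sketch):
#   check_if_runtime_args_all_str("trainerArgs", {"learning_rate": "0.01"})  # passes
#   check_if_runtime_args_all_str("trainerArgs", {"learning_rate": 0.01})    # raises ValueError
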
def check_if_data_preprocessor_config_cls_valid(
gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
"""
Check if dataPreprocessorArgs are all string.
Check if dataPreprocessorConfigClsPath is valid and importable.
"""
logger.info(
"Config validation check: if dataPreprocessorConfigClsPath and its args are valid."
)
data_preprocessor_config_cls_path = (
gbml_config_pb.dataset_config.data_preprocessor_config.data_preprocessor_config_cls_path
)
runtime_args: Dict[str, str] = dict(
gbml_config_pb.dataset_config.data_preprocessor_config.data_preprocessor_args
)
check_if_runtime_args_all_str(
args_name="dataPreprocessorArgs", runtime_args=runtime_args
)
try:
data_preprocessor_cls = os_utils.import_obj(data_preprocessor_config_cls_path)
data_preprocessor_config: DataPreprocessorConfig = data_preprocessor_cls(
**runtime_args
)
assert isinstance(data_preprocessor_config, DataPreprocessorConfig)
except Exception as e:
raise ValueError(
f"Invalid 'dataPreprocessorConfigClsPath' in frozen config: datasetConfig - dataPreprocessorConfig. "
f"'dataPreprocessorConfigClsPath' provided: {data_preprocessor_config_cls_path}. "
f"Error: {e}"
)
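
# Example (hypothetical URI; a minimal sketch assuming the wrapper exposes its
# underlying gbml_config_pb proto):
#   wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
#       gbml_config_uri=UriFactory.create_uri("gs://my-bucket/frozen_gbml_config.yaml")
#   )
#   check_if_data_preprocessor_config_cls_valid(gbml_config_pb=wrapper.gbml_config_pb)
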
def check_if_trainer_cls_valid(
gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
"""
Check if trainerArgs are all string.
Check if trainerClsPath is valid and importable.
"""
logger.info("Config validation check: if trainerClsPath and its args are valid.")
gbml_config_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
if gbml_config_wrapper.should_use_glt_backend:
        logger.warning(
            "Skipping trainer class validation as it is not yet implemented for the GLT Backend. "
            "Trainer class may actually be a path to a script, so the paradigm is different. "
            "This is temporary to unblock testing and will be refactored in the future."
        )
return
trainer_cls_path = gbml_config_pb.trainer_config.trainer_cls_path
runtime_args: Dict[str, str] = dict(gbml_config_pb.trainer_config.trainer_args)
check_if_runtime_args_all_str(args_name="trainerArgs", runtime_args=runtime_args)
try:
trainer_cls = os_utils.import_obj(trainer_cls_path)
trainer: BaseTrainer = trainer_cls(**runtime_args)
assert isinstance(trainer, BaseTrainer)
except Exception as e:
raise ValueError(
f"Invalid 'trainerClsPath' in frozen config: trainerConfig - trainerClsPath. "
f"'trainerClsPath' provided: {trainer_cls_path}. "
f"Error: {e}"
)
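
# A trainerConfig stanza that would pass this check might look like the following
# (hypothetical class path and args; a sketch only):
#   trainerConfig:
#     trainerClsPath: my_project.trainers.MyTrainer
#     trainerArgs:
#       num_epochs: "10"
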
def check_if_inferencer_cls_valid(
gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
"""
Check if inferencerArgs are all string.
Check if inferencerClsPath is valid and importable.
"""
logger.info("Config validation check: if inferencerClsPath and its args are valid.")
inferencer_cls_path = gbml_config_pb.inferencer_config.inferencer_cls_path
runtime_args: Dict[str, str] = dict(
gbml_config_pb.inferencer_config.inferencer_args
)
check_if_runtime_args_all_str(args_name="inferencerArgs", runtime_args=runtime_args)
gbml_config_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
if gbml_config_wrapper.should_use_glt_backend:
        logger.warning(
            "Skipping inferencer class validation as GLT Backend is enabled. "
            "Inferencer class may actually be a path to a script, so the paradigm is different. "
            "This is temporary to unblock testing and will be refactored in the future."
        )
return
try:
inferencer_cls = os_utils.import_obj(inferencer_cls_path)
inferencer_instance: BaseInferencer = inferencer_cls(**runtime_args)
assert isinstance(inferencer_instance, BaseInferencer)
except Exception as e:
raise ValueError(
f"Invalid 'inferencerClsPath' in frozen config: inferencerConfig - inferencerClsPath. "
f"'inferencerClsPath' provided: {inferencer_cls_path}. "
f"Error: {e}"
)
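
# An inferencerConfig stanza that would pass this check might look like the following
# (hypothetical class path and args; a sketch only):
#   inferencerConfig:
#     inferencerClsPath: my_project.inferencers.MyInferencer
#     inferencerArgs:
#       batch_size: "256"
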
def check_if_split_generator_config_valid(
gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
"""
Check if splitGeneratorConfig is valid.
"""
logger.info("Config validation check: if splitGeneratorConfig is valid.")
gbml_config_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
if gbml_config_wrapper.should_use_glt_backend:
logger.warning(
"Skipping splitGeneratorConfig validation as GLT Backend is enabled."
)
return
assigner_cls_path = (
gbml_config_pb.dataset_config.split_generator_config.assigner_cls_path
)
split_strategy_cls_path = (
gbml_config_pb.dataset_config.split_generator_config.split_strategy_cls_path
)
if not assigner_cls_path or not split_strategy_cls_path:
raise ValueError(
"Invalid class paths or class paths not provided in splitGeneratorConfig."
)
assigner_args = dict(
gbml_config_pb.dataset_config.split_generator_config.assigner_args
)
check_if_runtime_args_all_str(args_name="assignerArgs", runtime_args=assigner_args)
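
# A splitGeneratorConfig stanza that would pass this check might look like the
# following (hypothetical class paths and args; a sketch only):
#   datasetConfig:
#     splitGeneratorConfig:
#       splitStrategyClsPath: my_project.splits.MySplitStrategy
#       assignerClsPath: my_project.splits.MyAssigner
#       assignerArgs:
#         train_split: "0.8"
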
def check_if_subgraph_sampler_config_valid(
gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
"""
Check if subgraphSamplerConfig is valid.
"""
logger.info("Config validation check: if subgraphSamplerConfig is valid.")
gbml_config_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
if gbml_config_wrapper.should_use_glt_backend:
logger.warning(
"Skipping subgraph sampler (SGS) validation check since GLT Backend is being used."
)
return
subgraph_sampler_config = gbml_config_pb.dataset_config.subgraph_sampler_config
if subgraph_sampler_config.HasField("subgraph_sampling_strategy"):
subgraph_sampling_strategy_pb_wrapper = SubgraphSamplingStrategyPbWrapper(
subgraph_sampler_config.subgraph_sampling_strategy
)
subgraph_sampling_strategy_pb_wrapper.validate_dags(
graph_metadata_pb=gbml_config_pb.graph_metadata,
task_metadata_pb=gbml_config_pb.task_metadata,
)
else:
num_hops = subgraph_sampler_config.num_hops
num_neighbors_to_sample = subgraph_sampler_config.num_neighbors_to_sample
if num_hops <= 0:
raise ValueError("Invalid numHops in subgraphSamplerConfig.")
if num_neighbors_to_sample <= 0:
raise ValueError("Invalid numNeighborsToSample in subgraphSamplerConfig.")
num_positive_samples = subgraph_sampler_config.num_positive_samples
num_user_defined_positive_samples = (
subgraph_sampler_config.num_user_defined_positive_samples
)
num_user_defined_negative_samples = (
subgraph_sampler_config.num_user_defined_negative_samples
)
num_max_training_samples_to_output = (
subgraph_sampler_config.num_max_training_samples_to_output
)
if num_positive_samples < 0:
raise ValueError("Invalid numPositiveSamples in subgraphSamplerConfig.")
if num_user_defined_positive_samples < 0:
raise ValueError(
"Invalid numUserDefinedPositiveSamples in subgraphSamplerConfig."
)
if num_user_defined_negative_samples < 0:
raise ValueError(
"Invalid numUserDefinedNegativeSamples in subgraphSamplerConfig."
)
        if num_user_defined_positive_samples > 0 and num_positive_samples > 0:
            raise ValueError(
                "Can provide either num_positive_samples or num_user_defined_positive_samples; not both."
            )
        if num_user_defined_positive_samples + num_positive_samples <= 0:
            raise ValueError(
                "Must provide either num_positive_samples or num_user_defined_positive_samples."
            )
if num_max_training_samples_to_output < 0:
raise ValueError(
"Invalid numMaxTrainingSamplesToOutput in subgraphSamplerConfig."
)
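
# A subgraphSamplerConfig stanza (without a subgraphSamplingStrategy) that would pass
# these checks might look like the following (hypothetical values; a sketch only):
#   datasetConfig:
#     subgraphSamplerConfig:
#       numHops: 2
#       numNeighborsToSample: 10
#       numPositiveSamples: 3
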
def check_if_post_processor_cls_valid(
gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
"""
Check if postProcessorArgs are all string.
Check if postProcessorClsPath is valid and importable.
"""
logger.info(
"Config validation check: if postProcessorClsPath and its args are valid."
)
post_processor_cls_path = (
gbml_config_pb.post_processor_config.post_processor_cls_path
)
if not post_processor_cls_path:
logger.info(
"No post processor class provided - skipping checks for post processor"
)
return
runtime_args: Dict[str, str] = dict(
gbml_config_pb.post_processor_config.post_processor_args
)
check_if_runtime_args_all_str(
args_name="postProcessorArgs", runtime_args=runtime_args
)
try:
post_processor_cls = os_utils.import_obj(post_processor_cls_path)
post_processor_instance: BasePostProcessor = post_processor_cls(**runtime_args)
assert isinstance(post_processor_instance, BasePostProcessor)
except Exception as e:
        raise ValueError(
            f"Invalid 'postProcessorClsPath' in frozen config, and/or 'postProcessorArgs' could not "
            f"successfully initialize the provided class. 'postProcessorClsPath' provided: {post_processor_cls_path}. "
            f"Error: {e}"
        )
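
# Example: when postProcessorClsPath is unset this check is a no-op, so an
# otherwise-empty config passes (a minimal, runnable sketch):
#   check_if_post_processor_cls_valid(gbml_config_pb=gbml_config_pb2.GbmlConfig())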