from typing import Optional, Set
from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.task_metadata import TaskMetadataPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.common.utils.file_loader import FileLoader
from snapchat.research.gbml import gbml_config_pb2
# Module-level singletons shared by all validation helpers below.
# NOTE: `logger` is referenced by every function in this module but its binding
# was missing (Logger is imported above) — restoring it fixes a NameError at call time.
logger = Logger()
file_loader = FileLoader()
def assert_asset_exists(
    resource_name: str,
    uri: Uri,
    file_name_suffix: Optional[str] = None,
) -> None:
    """Raise ValueError unless at least one asset exists under ``uri``.

    Args:
        resource_name: Name of the frozen-config field being validated
            (used only in log/error messages).
        uri: URI prefix to search for assets.
        file_name_suffix: Optional suffix filter (e.g. ``".tfrecord"``).

    Raises:
        ValueError: If no matching asset is found.
    """
    logger.info(
        f"Config validation check: if {resource_name} at {uri}*{file_name_suffix} exists."
    )
    num_found = file_loader.count_assets(uri_prefix=uri, suffix=file_name_suffix)
    if num_found >= 1:
        return
    raise ValueError(
        f"Required resource does not exist, "
        f"file path specified in frozen config: sharedConfig - {resource_name}. "
        f"'{resource_name}' provided: {uri} "
    )
def assert_trained_model_exists(
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
    """
    Check if trained model file exists.

    No-op when the GLT backend is in use, since the model is not expected to
    be piped through gigl-specific configs in that case.
    """
    wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
    if wrapper.should_use_glt_backend:
        logger.warning(
            "Skipping trained model check since GLT Backend is being used."
            + "Currently it is not expected that model be piped in through gigl specific configs. "
            + "This will be updated in the future."
        )
        return
    trained_model_uri = UriFactory.create_uri(
        gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri
    )
    assert_asset_exists(resource_name="trainedModelUri", uri=trained_model_uri)
def assert_split_generator_output_exists(
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
    """
    Check if split generator output files exist.

    For node-anchor-based link prediction tasks, train/val splits (including
    per-node-type random negative samples) are only checked when training is
    not skipped; the test split is always checked. Node-based and link-based
    tasks check their train/val/test data URIs analogously. Skipped entirely
    when the GLT backend is in use.

    Raises:
        ValueError: If any required split output is missing.
    """
    gbml_config_pb_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
    if gbml_config_pb_wrapper.should_use_glt_backend:
        logger.warning(
            "Skipping split generator output check since GLT Backend is being used."
        )
        return

    def _assert_tfrecords_exist(resource_name: str, uri_str: str) -> None:
        # Every split generator output is a set of sharded .tfrecord files
        # under a URI prefix; centralizes the nine identical call sites.
        assert_asset_exists(
            resource_name=resource_name,
            uri=UriFactory.create_uri(uri_str),
            file_name_suffix=".tfrecord",
        )

    task_metadata_type = (
        gbml_config_pb_wrapper.task_metadata_pb_wrapper.task_metadata_type
    )
    dataset_metadata_pb = gbml_config_pb.shared_config.dataset_metadata
    should_skip_training = gbml_config_pb.shared_config.should_skip_training
    if task_metadata_type == TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK:
        nalp_dataset = dataset_metadata_pb.node_anchor_based_link_prediction_dataset
        # node types for which random negative samples are generated
        # Only target node types are considered for random negative samples in Split Generator
        random_negative_node_types = gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_node_types(
            should_include_src_nodes=False,
            should_include_dst_nodes=True,
        )
        if not should_skip_training:
            _assert_tfrecords_exist(
                "trainMainDataUri", nalp_dataset.train_main_data_uri
            )
            _assert_tfrecords_exist("valMainDataUri", nalp_dataset.val_main_data_uri)
            for node_type in random_negative_node_types:
                _assert_tfrecords_exist(
                    "trainRandomNegativeDataUri",
                    nalp_dataset.train_node_type_to_random_negative_data_uri[
                        node_type
                    ],
                )
                _assert_tfrecords_exist(
                    "valRandomNegativeDataUri",
                    nalp_dataset.val_node_type_to_random_negative_data_uri[node_type],
                )
        # Test split is validated even when training is skipped.
        _assert_tfrecords_exist("testMainDataUri", nalp_dataset.test_main_data_uri)
        for node_type in random_negative_node_types:
            _assert_tfrecords_exist(
                "testRandomNegativeDataUri",
                nalp_dataset.test_node_type_to_random_negative_data_uri[node_type],
            )
    elif task_metadata_type == TaskMetadataType.NODE_BASED_TASK:
        snc_dataset = dataset_metadata_pb.supervised_node_classification_dataset
        if not should_skip_training:
            _assert_tfrecords_exist("trainDataUri", snc_dataset.train_data_uri)
            _assert_tfrecords_exist("valDataUri", snc_dataset.val_data_uri)
        _assert_tfrecords_exist("testDataUri", snc_dataset.test_data_uri)
    elif task_metadata_type == TaskMetadataType.LINK_BASED_TASK:
        slt_dataset = dataset_metadata_pb.supervised_link_based_task_dataset
        if not should_skip_training:
            _assert_tfrecords_exist("trainDataUri", slt_dataset.train_data_uri)
            _assert_tfrecords_exist("valDataUri", slt_dataset.val_data_uri)
        _assert_tfrecords_exist("testDataUri", slt_dataset.test_data_uri)
def assert_subgraph_sampler_output_exists(
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> None:
    """
    Check if subgraph sampler output files exist.

    Validates the flattened graph outputs for the configured task type;
    skipped entirely when the GLT backend is in use.
    """
    config_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
    if config_wrapper.should_use_glt_backend:
        logger.warning(
            "Skipping subgraph sampler output check since GLT Backend is being used."
        )
        return
    task_metadata_wrapper = TaskMetadataPbWrapper(gbml_config_pb.task_metadata)
    task_type = task_metadata_wrapper.task_metadata_type
    flattened_graph_metadata_pb = gbml_config_pb.shared_config.flattened_graph_metadata

    def _check(resource_name: str, uri_str: str) -> None:
        # All subgraph sampler outputs are sharded .tfrecord files under a prefix.
        assert_asset_exists(
            resource_name=resource_name,
            uri=UriFactory.create_uri(uri_str),
            file_name_suffix=".tfrecord",
        )

    if task_type == TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK:
        nalp_output = (
            flattened_graph_metadata_pb.node_anchor_based_link_prediction_output
        )
        _check("tfrecordUriPrefix", nalp_output.tfrecord_uri_prefix)
        assert isinstance(
            task_metadata_wrapper.task_metadata,
            gbml_config_pb2.GbmlConfig.TaskMetadata.NodeAnchorBasedLinkPredictionTaskMetadata,
        )
        # Unlike the split generator check, both src and dst node types of every
        # supervision edge type are expected to have random negative outputs here.
        random_negative_node_types: Set[str] = set()
        for edge_type in task_metadata_wrapper.task_metadata.supervision_edge_types:
            random_negative_node_types.update(
                (edge_type.src_node_type, edge_type.dst_node_type)
            )
        for node_type in random_negative_node_types:
            _check(
                "randomNegativeTfrecordUriPrefix",
                nalp_output.node_type_to_random_negative_tfrecord_uri_prefix[
                    node_type
                ],
            )
    elif task_type == TaskMetadataType.NODE_BASED_TASK:
        snc_output = flattened_graph_metadata_pb.supervised_node_classification_output
        _check("labeledTfrecordUriPrefix", snc_output.labeled_tfrecord_uri_prefix)
        _check("unlabeledTfrecordUriPrefix", snc_output.unlabeled_tfrecord_uri_prefix)
    elif task_type == TaskMetadataType.LINK_BASED_TASK:
        slt_output = flattened_graph_metadata_pb.supervised_link_based_task_output
        _check("labeledTfrecordUriPrefix", slt_output.labeled_tfrecord_uri_prefix)
        _check("unlabeledTfrecordUriPrefix", slt_output.unlabeled_tfrecord_uri_prefix)