from typing import Optional
import gigl.src.mocking.lib.constants as mocking_constants
from gigl.common import GcsUri
from gigl.common.logger import Logger
from gigl.common.utils.gcs import GcsUtils
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.common.utils.bq import BqUtils
from gigl.src.config_populator.config_populator import ConfigPopulator
from gigl.src.mocking.lib import (
mock_input_for_data_preprocessor,
mock_input_for_inference,
mock_input_for_split_generator,
mock_input_for_subgraph_sampler,
mock_input_for_trainer,
mock_output_for_inference,
)
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
from snapchat.research.gbml import gbml_config_pb2, graph_schema_pb2
[docs]
class DatasetAssetMocker:
"""
Enables functionality to mock the input / output assets of all components based on input graph data.
Useful to (re-)generate assets which can be used for testing.
"""
def __init__(self) -> None:
self.__proto_utils = ProtoUtils()
def _update_supervised_node_classification_config_paths(
self,
pb: gbml_config_pb2.GbmlConfig,
root_node_type: Optional[NodeType],
):
modeling_task_spec_path = (
"gigl."
"src."
"common."
"modeling_task_specs."
"node_classification_modeling_task_spec."
"NodeClassificationModelingTaskSpec"
)
assert (
root_node_type in self._mocked_dataset_info.num_node_distinct_labels
), f"Need labels for node type {root_node_type} to mock for supervised tasks."
kwargs = {
"batch_size": "16",
"out_dim": str(
self._mocked_dataset_info.num_node_distinct_labels[root_node_type]
),
"num_epochs": "1",
}
pb.trainer_config.trainer_cls_path = modeling_task_spec_path
pb.trainer_config.trainer_args.update(kwargs)
pb.inferencer_config.inferencer_cls_path = modeling_task_spec_path
pb.inferencer_config.inferencer_args.update(kwargs)
task_output = (
pb.shared_config.flattened_graph_metadata.supervised_node_classification_output
)
task_output.labeled_tfrecord_uri_prefix = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_output.labeled_tfrecord_uri_prefix, version=self._version
)
)
task_output.unlabeled_tfrecord_uri_prefix = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_output.unlabeled_tfrecord_uri_prefix, version=self._version
)
)
task_dataset = (
pb.shared_config.dataset_metadata.supervised_node_classification_dataset
)
task_dataset.train_data_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_dataset.train_data_uri, version=self._version
)
)
task_dataset.val_data_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_dataset.val_data_uri, version=self._version
)
)
task_dataset.test_data_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_dataset.test_data_uri, version=self._version
)
)
node_type_to_inferencer_output_info_map = (
pb.shared_config.inference_metadata.node_type_to_inferencer_output_info_map
)
for node_type in node_type_to_inferencer_output_info_map:
node_type_to_inferencer_output_info_map[
node_type
].predictions_path = mocking_constants.update_bq_table_with_test_assets_and_version(
bq_table=node_type_to_inferencer_output_info_map[
node_type
].predictions_path,
version=self._version,
)
node_type_to_inferencer_output_info_map[
node_type
].embeddings_path = mocking_constants.update_bq_table_with_test_assets_and_version(
bq_table=node_type_to_inferencer_output_info_map[
node_type
].embeddings_path,
version=self._version,
)
def _update_node_anchor_based_link_prediction_config_paths(
self,
pb: gbml_config_pb2.GbmlConfig,
):
modeling_task_spec_path = (
"gigl."
"src."
"common."
"modeling_task_specs."
"node_anchor_based_link_prediction_modeling_task_spec."
"NodeAnchorBasedLinkPredictionModelingTaskSpec"
)
kwargs = {
"main_sample_batch_size": "4",
"random_negative_sample_batch_size": "4",
"random_negative_sample_batch_size_for_evaluation": "4",
"num_val_batches": "4",
"num_test_batches": "4",
"val_every_num_batches": "4",
"early_stop_patience": "1",
}
graph_metadata_pb_wrapper = GraphMetadataPbWrapper(
graph_metadata_pb=pb.graph_metadata
)
if graph_metadata_pb_wrapper.is_heterogeneous:
kwargs.update(
{"gnn_model_class_path": "gigl.src.common.models.pyg.heterogeneous.HGT"}
)
pb.trainer_config.trainer_cls_path = modeling_task_spec_path
pb.trainer_config.trainer_args.update(kwargs)
pb.inferencer_config.inferencer_cls_path = modeling_task_spec_path
pb.inferencer_config.inferencer_args.update(kwargs)
task_output = (
pb.shared_config.flattened_graph_metadata.node_anchor_based_link_prediction_output
)
task_output.tfrecord_uri_prefix = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_output.tfrecord_uri_prefix, version=self._version
)
)
for (
node_type,
random_negative_tfrecord_uri_prefix,
) in task_output.node_type_to_random_negative_tfrecord_uri_prefix.items():
task_output.node_type_to_random_negative_tfrecord_uri_prefix[
node_type
] = mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=random_negative_tfrecord_uri_prefix, version=self._version
)
task_dataset = (
pb.shared_config.dataset_metadata.node_anchor_based_link_prediction_dataset
)
task_dataset.train_main_data_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_dataset.train_main_data_uri, version=self._version
)
)
task_dataset.test_main_data_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_dataset.test_main_data_uri, version=self._version
)
)
task_dataset.val_main_data_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=task_dataset.val_main_data_uri, version=self._version
)
)
for (
node_type,
random_negative_tfrecord_uri_prefix,
) in task_dataset.train_node_type_to_random_negative_data_uri.items():
task_dataset.train_node_type_to_random_negative_data_uri[
node_type
] = mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=random_negative_tfrecord_uri_prefix, version=self._version
)
for (
node_type,
random_negative_tfrecord_uri_prefix,
) in task_dataset.val_node_type_to_random_negative_data_uri.items():
task_dataset.val_node_type_to_random_negative_data_uri[
node_type
] = mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=random_negative_tfrecord_uri_prefix, version=self._version
)
for (
node_type,
random_negative_tfrecord_uri_prefix,
) in task_dataset.test_node_type_to_random_negative_data_uri.items():
task_dataset.test_node_type_to_random_negative_data_uri[
node_type
] = mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=random_negative_tfrecord_uri_prefix, version=self._version
)
inference_metadata = pb.shared_config.inference_metadata
for node_type in inference_metadata.node_type_to_inferencer_output_info_map:
inference_metadata.node_type_to_inferencer_output_info_map[
node_type
].embeddings_path = mocking_constants.update_bq_table_with_test_assets_and_version(
bq_table=inference_metadata.node_type_to_inferencer_output_info_map[
node_type
].embeddings_path,
version=self._version,
)
def _prepare_frozen_gbml_config_shared(
self, task_metadata_pb: gbml_config_pb2.GbmlConfig.TaskMetadata
) -> gbml_config_pb2.GbmlConfig:
applied_task_identifier = AppliedTaskIdentifier(self._mocked_dataset_info.name)
graph_metadata_pb = (
self._mocked_dataset_info.graph_metadata_pb_wrapper.graph_metadata_pb
)
template_gbml_config_pb = gbml_config_pb2.GbmlConfig(
task_metadata=task_metadata_pb,
graph_metadata=graph_metadata_pb,
)
config_populator = ConfigPopulator()
frozen_gbml_config_pb = config_populator._populate_frozen_gbml_config_pb(
applied_task_identifier=applied_task_identifier,
template_gbml_config_pb=template_gbml_config_pb,
)
frozen_gbml_config_pb.shared_config.preprocessed_metadata_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=frozen_gbml_config_pb.shared_config.preprocessed_metadata_uri,
version=self._version,
)
)
trained_model_metadata = (
frozen_gbml_config_pb.shared_config.trained_model_metadata
)
trained_model_metadata.trained_model_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=trained_model_metadata.trained_model_uri, version=self._version
)
)
trained_model_metadata.scripted_model_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=trained_model_metadata.scripted_model_uri, version=self._version
)
)
trained_model_metadata.eval_metrics_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=trained_model_metadata.eval_metrics_uri, version=self._version
)
)
trained_model_metadata.tensorboard_logs_uri = (
mocking_constants.update_gcs_uri_with_test_assets_and_version(
uri_str=trained_model_metadata.tensorboard_logs_uri,
version=self._version,
)
)
return frozen_gbml_config_pb
def _populate_and_write_frozen_gbml_config(
self, frozen_gbml_config_pb: gbml_config_pb2.GbmlConfig
) -> None:
self._frozen_gbml_config_pb = frozen_gbml_config_pb
logger.info(self._frozen_gbml_config_pb)
frozen_gbml_config_gcs_uri = (
mocking_constants.get_example_task_frozen_gbml_config_gcs_path(
task_name=self._mocked_dataset_info.name, version=self._version
)
)
self.__proto_utils.write_proto_to_yaml(
proto=self._frozen_gbml_config_pb, uri=frozen_gbml_config_gcs_uri
)
def _prepare_supervised_node_classification_frozen_gbml_config(
self, sample_node_type: NodeType
):
task_metadata_pb = gbml_config_pb2.GbmlConfig.TaskMetadata(
node_based_task_metadata=gbml_config_pb2.GbmlConfig.TaskMetadata.NodeBasedTaskMetadata(
supervision_node_types=[str(sample_node_type)]
)
)
frozen_gbml_config_pb = self._prepare_frozen_gbml_config_shared(
task_metadata_pb=task_metadata_pb
)
self._update_supervised_node_classification_config_paths(
pb=frozen_gbml_config_pb, root_node_type=sample_node_type
)
self._populate_and_write_frozen_gbml_config(frozen_gbml_config_pb)
def _prepare_node_anchor_based_link_prediction_frozen_gbml_config(
self, sample_edge_type: EdgeType
):
task_metadata_pb = gbml_config_pb2.GbmlConfig.TaskMetadata(
node_anchor_based_link_prediction_task_metadata=gbml_config_pb2.GbmlConfig.TaskMetadata.NodeAnchorBasedLinkPredictionTaskMetadata(
supervision_edge_types=[
graph_schema_pb2.EdgeType(
src_node_type=sample_edge_type.src_node_type,
relation=sample_edge_type.relation,
dst_node_type=sample_edge_type.dst_node_type,
)
]
)
)
frozen_gbml_config_pb = self._prepare_frozen_gbml_config_shared(
task_metadata_pb=task_metadata_pb
)
self._update_node_anchor_based_link_prediction_config_paths(
pb=frozen_gbml_config_pb
)
self._populate_and_write_frozen_gbml_config(frozen_gbml_config_pb)
def _mock_supervised_node_classification_assets(self):
# Prepare GCS and BQ assets / environment.
self._prepare_env()
# Prepare frozen GbmlConfig.
assert (
self._mocked_dataset_info.sample_node_type is not None
), f"Need defined sample_node_type to mock for {TaskMetadataType.NODE_BASED_TASK} task."
self._prepare_supervised_node_classification_frozen_gbml_config(
sample_node_type=self._mocked_dataset_info.sample_node_type
)
# Upload assets to BQ
mock_input_for_data_preprocessor.generate_bigquery_assets(
mocked_dataset_info=self._mocked_dataset_info, version=self._version
)
# Mock SubgraphSampler inputs ("run Data Preprocessor")
mock_input_for_subgraph_sampler.generate_preprocessed_tfrecord_data(
mocked_dataset_info=self._mocked_dataset_info,
version=self._version,
gbml_config_pb=self._frozen_gbml_config_pb,
)
# Mock SplitGenerator inputs ("run Subgraph Sampler")
hetero_data = mock_input_for_split_generator.build_and_write_supervised_node_classification_subgraph_samples_from_mocked_dataset_info(
mocked_dataset_info=self._mocked_dataset_info,
root_node_type=self._mocked_dataset_info.sample_node_type,
gbml_config_pb=self._frozen_gbml_config_pb,
)
# Mock Trainer inputs ("run Split Generator")
mock_input_for_trainer.split_and_write_supervised_node_classification_subgraph_samples_from_mocked_dataset_info(
mocked_dataset_info=self._mocked_dataset_info,
root_node_type=self._mocked_dataset_info.sample_node_type,
gbml_config_pb=self._frozen_gbml_config_pb,
hetero_data=hetero_data,
)
# Mock Inferencer inputs ("run Trainer")
mock_input_for_inference.train_model(
gbml_config_pb=self._frozen_gbml_config_pb,
)
# Mock Inferencer outputs ("run Inferencer")
mock_output_for_inference.infer_model(
gbml_config_pb=self._frozen_gbml_config_pb,
)
def _mock_node_anchor_based_link_prediction_assets(self):
# Prepare GCS and BQ assets / environment.
self._prepare_env()
# Prepare frozen GbmlConfig.
assert (
self._mocked_dataset_info.sample_edge_type is not None
), f"Need defined sample_edge_type to mock for {TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK} task."
self._prepare_node_anchor_based_link_prediction_frozen_gbml_config(
sample_edge_type=self._mocked_dataset_info.sample_edge_type
)
# Upload assets to BQ
mock_input_for_data_preprocessor.generate_bigquery_assets(
mocked_dataset_info=self._mocked_dataset_info, version=self._version
)
# Mock SubgraphSampler inputs ("run Data Preprocessor")
mock_input_for_subgraph_sampler.generate_preprocessed_tfrecord_data(
mocked_dataset_info=self._mocked_dataset_info,
version=self._version,
gbml_config_pb=self._frozen_gbml_config_pb,
)
# Mock SplitGenerator inputs ("run Subgraph Sampler")
hetero_data = mock_input_for_split_generator.build_and_write_node_anchor_link_prediction_subgraph_samples_from_mocked_dataset_info(
mocked_dataset_info=self._mocked_dataset_info,
sample_edge_type=self._mocked_dataset_info.sample_edge_type,
gbml_config_pb=self._frozen_gbml_config_pb,
)
# Mock Trainer inputs ("run Split Generator")
mock_input_for_trainer.split_and_write_node_anchor_link_prediction_subgraph_samples_from_mocked_dataset_info(
mocked_dataset_info=self._mocked_dataset_info,
sample_edge_type=self._mocked_dataset_info.sample_edge_type,
gbml_config_pb=self._frozen_gbml_config_pb,
hetero_data=hetero_data,
)
# Mock Inferencer inputs ("run Trainer")
mock_input_for_inference.train_model(
gbml_config_pb=self._frozen_gbml_config_pb,
)
# Mock Inferencer outputs ("run Inferencer")
mock_output_for_inference.infer_model(
gbml_config_pb=self._frozen_gbml_config_pb,
)
def _prepare_env(self):
bq_utils = BqUtils()
bq_utils.create_bq_dataset(
dataset_id=mocking_constants.MOCK_DATA_BQ_DATASET_NAME, exists_ok=True
)
gcs_utils = GcsUtils()
gcs_utils.delete_files_in_bucket_dir(
gcs_path=mocking_constants.get_example_task_static_assets_gcs_dir(
task_name=self._mocked_dataset_info.name, version=self._version
)
)
[docs]
def mock_assets(self, mocked_dataset_info: MockedDatasetInfo) -> GcsUri:
self._mocked_dataset_info = mocked_dataset_info
assert (
mocked_dataset_info.version is not None
), "Need defined version to mock assets."
self._version = mocked_dataset_info.version
if mocked_dataset_info.task_metadata_type == TaskMetadataType.NODE_BASED_TASK:
assert (
mocked_dataset_info.sample_node_type is not None
), f"Need defined sample_node_type to mock for {TaskMetadataType.NODE_BASED_TASK} task."
self._mock_supervised_node_classification_assets()
elif (
mocked_dataset_info.task_metadata_type
== TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK
):
assert (
mocked_dataset_info.sample_edge_type is not None
), f"Need defined sample_edge_type to mock for {TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK} task."
self._mock_node_anchor_based_link_prediction_assets()
else:
raise NotImplementedError
frozen_gbml_config_uri = (
mocking_constants.get_example_task_frozen_gbml_config_gcs_path(
task_name=self._mocked_dataset_info.name, version=self._version
)
)
return frozen_gbml_config_uri