Source code for gigl.src.mocking.dataset_asset_mocker

from typing import Optional

import gigl.src.mocking.lib.constants as mocking_constants
from gigl.common import GcsUri
from gigl.common.logger import Logger
from gigl.common.utils.gcs import GcsUtils
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from gigl.src.common.utils.bq import BqUtils
from gigl.src.config_populator.config_populator import ConfigPopulator
from gigl.src.mocking.lib import (
    mock_input_for_data_preprocessor,
    mock_input_for_inference,
    mock_input_for_split_generator,
    mock_input_for_subgraph_sampler,
    mock_input_for_trainer,
    mock_output_for_inference,
)
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
from snapchat.research.gbml import gbml_config_pb2, graph_schema_pb2

logger = Logger()
class DatasetAssetMocker:
    """
    Enables functionality to mock the input / output assets of all components
    based on input graph data. Useful to (re-)generate assets which can be used
    for testing.
    """

    def __init__(self) -> None:
        self.__proto_utils = ProtoUtils()

    def _update_supervised_node_classification_config_paths(
        self,
        pb: gbml_config_pb2.GbmlConfig,
        root_node_type: Optional[NodeType],
    ):
        modeling_task_spec_path = (
            "gigl."
            "src."
            "common."
            "modeling_task_specs."
            "node_classification_modeling_task_spec."
            "NodeClassificationModelingTaskSpec"
        )
        assert (
            root_node_type in self._mocked_dataset_info.num_node_distinct_labels
        ), f"Need labels for node type {root_node_type} to mock for supervised tasks."
        kwargs = {
            "batch_size": "16",
            "out_dim": str(
                self._mocked_dataset_info.num_node_distinct_labels[root_node_type]
            ),
            "num_epochs": "1",
        }
        pb.trainer_config.trainer_cls_path = modeling_task_spec_path
        pb.trainer_config.trainer_args.update(kwargs)
        pb.inferencer_config.inferencer_cls_path = modeling_task_spec_path
        pb.inferencer_config.inferencer_args.update(kwargs)

        task_output = (
            pb.shared_config.flattened_graph_metadata.supervised_node_classification_output
        )
        task_output.labeled_tfrecord_uri_prefix = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_output.labeled_tfrecord_uri_prefix, version=self._version
            )
        )
        task_output.unlabeled_tfrecord_uri_prefix = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_output.unlabeled_tfrecord_uri_prefix, version=self._version
            )
        )

        task_dataset = (
            pb.shared_config.dataset_metadata.supervised_node_classification_dataset
        )
        task_dataset.train_data_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_dataset.train_data_uri, version=self._version
            )
        )
        task_dataset.val_data_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_dataset.val_data_uri, version=self._version
            )
        )
        task_dataset.test_data_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_dataset.test_data_uri, version=self._version
            )
        )

        node_type_to_inferencer_output_info_map = (
            pb.shared_config.inference_metadata.node_type_to_inferencer_output_info_map
        )
        for node_type in node_type_to_inferencer_output_info_map:
            node_type_to_inferencer_output_info_map[
                node_type
            ].predictions_path = mocking_constants.update_bq_table_with_test_assets_and_version(
                bq_table=node_type_to_inferencer_output_info_map[
                    node_type
                ].predictions_path,
                version=self._version,
            )
            node_type_to_inferencer_output_info_map[
                node_type
            ].embeddings_path = mocking_constants.update_bq_table_with_test_assets_and_version(
                bq_table=node_type_to_inferencer_output_info_map[
                    node_type
                ].embeddings_path,
                version=self._version,
            )

    def _update_node_anchor_based_link_prediction_config_paths(
        self,
        pb: gbml_config_pb2.GbmlConfig,
    ):
        modeling_task_spec_path = (
            "gigl."
            "src."
            "common."
            "modeling_task_specs."
            "node_anchor_based_link_prediction_modeling_task_spec."
            "NodeAnchorBasedLinkPredictionModelingTaskSpec"
        )
        kwargs = {
            "main_sample_batch_size": "4",
            "random_negative_sample_batch_size": "4",
            "random_negative_sample_batch_size_for_evaluation": "4",
            "num_val_batches": "4",
            "num_test_batches": "4",
            "val_every_num_batches": "4",
            "early_stop_patience": "1",
        }
        graph_metadata_pb_wrapper = GraphMetadataPbWrapper(
            graph_metadata_pb=pb.graph_metadata
        )
        if graph_metadata_pb_wrapper.is_heterogeneous:
            kwargs.update(
                {"gnn_model_class_path": "gigl.src.common.models.pyg.heterogeneous.HGT"}
            )
        pb.trainer_config.trainer_cls_path = modeling_task_spec_path
        pb.trainer_config.trainer_args.update(kwargs)
        pb.inferencer_config.inferencer_cls_path = modeling_task_spec_path
        pb.inferencer_config.inferencer_args.update(kwargs)

        task_output = (
            pb.shared_config.flattened_graph_metadata.node_anchor_based_link_prediction_output
        )
        task_output.tfrecord_uri_prefix = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_output.tfrecord_uri_prefix, version=self._version
            )
        )
        for (
            node_type,
            random_negative_tfrecord_uri_prefix,
        ) in task_output.node_type_to_random_negative_tfrecord_uri_prefix.items():
            task_output.node_type_to_random_negative_tfrecord_uri_prefix[
                node_type
            ] = mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=random_negative_tfrecord_uri_prefix, version=self._version
            )

        task_dataset = (
            pb.shared_config.dataset_metadata.node_anchor_based_link_prediction_dataset
        )
        task_dataset.train_main_data_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_dataset.train_main_data_uri, version=self._version
            )
        )
        task_dataset.test_main_data_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_dataset.test_main_data_uri, version=self._version
            )
        )
        task_dataset.val_main_data_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=task_dataset.val_main_data_uri, version=self._version
            )
        )
        for (
            node_type,
            random_negative_tfrecord_uri_prefix,
        ) in task_dataset.train_node_type_to_random_negative_data_uri.items():
            task_dataset.train_node_type_to_random_negative_data_uri[
                node_type
            ] = mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=random_negative_tfrecord_uri_prefix, version=self._version
            )
        for (
            node_type,
            random_negative_tfrecord_uri_prefix,
        ) in task_dataset.val_node_type_to_random_negative_data_uri.items():
            task_dataset.val_node_type_to_random_negative_data_uri[
                node_type
            ] = mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=random_negative_tfrecord_uri_prefix, version=self._version
            )
        for (
            node_type,
            random_negative_tfrecord_uri_prefix,
        ) in task_dataset.test_node_type_to_random_negative_data_uri.items():
            task_dataset.test_node_type_to_random_negative_data_uri[
                node_type
            ] = mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=random_negative_tfrecord_uri_prefix, version=self._version
            )

        inference_metadata = pb.shared_config.inference_metadata
        for node_type in inference_metadata.node_type_to_inferencer_output_info_map:
            inference_metadata.node_type_to_inferencer_output_info_map[
                node_type
            ].embeddings_path = mocking_constants.update_bq_table_with_test_assets_and_version(
                bq_table=inference_metadata.node_type_to_inferencer_output_info_map[
                    node_type
                ].embeddings_path,
                version=self._version,
            )

    def _prepare_frozen_gbml_config_shared(
        self, task_metadata_pb: gbml_config_pb2.GbmlConfig.TaskMetadata
    ) -> gbml_config_pb2.GbmlConfig:
        applied_task_identifier = AppliedTaskIdentifier(self._mocked_dataset_info.name)
        graph_metadata_pb = (
            self._mocked_dataset_info.graph_metadata_pb_wrapper.graph_metadata_pb
        )
        template_gbml_config_pb = gbml_config_pb2.GbmlConfig(
            task_metadata=task_metadata_pb,
            graph_metadata=graph_metadata_pb,
        )
        config_populator = ConfigPopulator()
        frozen_gbml_config_pb = config_populator._populate_frozen_gbml_config_pb(
            applied_task_identifier=applied_task_identifier,
            template_gbml_config_pb=template_gbml_config_pb,
        )
        frozen_gbml_config_pb.shared_config.preprocessed_metadata_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=frozen_gbml_config_pb.shared_config.preprocessed_metadata_uri,
                version=self._version,
            )
        )
        trained_model_metadata = (
            frozen_gbml_config_pb.shared_config.trained_model_metadata
        )
        trained_model_metadata.trained_model_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=trained_model_metadata.trained_model_uri, version=self._version
            )
        )
        trained_model_metadata.scripted_model_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=trained_model_metadata.scripted_model_uri, version=self._version
            )
        )
        trained_model_metadata.eval_metrics_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=trained_model_metadata.eval_metrics_uri, version=self._version
            )
        )
        trained_model_metadata.tensorboard_logs_uri = (
            mocking_constants.update_gcs_uri_with_test_assets_and_version(
                uri_str=trained_model_metadata.tensorboard_logs_uri,
                version=self._version,
            )
        )
        return frozen_gbml_config_pb

    def _populate_and_write_frozen_gbml_config(
        self, frozen_gbml_config_pb: gbml_config_pb2.GbmlConfig
    ) -> None:
        self._frozen_gbml_config_pb = frozen_gbml_config_pb
        logger.info(self._frozen_gbml_config_pb)
        frozen_gbml_config_gcs_uri = (
            mocking_constants.get_example_task_frozen_gbml_config_gcs_path(
                task_name=self._mocked_dataset_info.name, version=self._version
            )
        )
        self.__proto_utils.write_proto_to_yaml(
            proto=self._frozen_gbml_config_pb, uri=frozen_gbml_config_gcs_uri
        )

    def _prepare_supervised_node_classification_frozen_gbml_config(
        self, sample_node_type: NodeType
    ):
        task_metadata_pb = gbml_config_pb2.GbmlConfig.TaskMetadata(
            node_based_task_metadata=gbml_config_pb2.GbmlConfig.TaskMetadata.NodeBasedTaskMetadata(
                supervision_node_types=[str(sample_node_type)]
            )
        )
        frozen_gbml_config_pb = self._prepare_frozen_gbml_config_shared(
            task_metadata_pb=task_metadata_pb
        )
        self._update_supervised_node_classification_config_paths(
            pb=frozen_gbml_config_pb, root_node_type=sample_node_type
        )
        self._populate_and_write_frozen_gbml_config(frozen_gbml_config_pb)

    def _prepare_node_anchor_based_link_prediction_frozen_gbml_config(
        self, sample_edge_type: EdgeType
    ):
        task_metadata_pb = gbml_config_pb2.GbmlConfig.TaskMetadata(
            node_anchor_based_link_prediction_task_metadata=gbml_config_pb2.GbmlConfig.TaskMetadata.NodeAnchorBasedLinkPredictionTaskMetadata(
                supervision_edge_types=[
                    graph_schema_pb2.EdgeType(
                        src_node_type=sample_edge_type.src_node_type,
                        relation=sample_edge_type.relation,
                        dst_node_type=sample_edge_type.dst_node_type,
                    )
                ]
            )
        )
        frozen_gbml_config_pb = self._prepare_frozen_gbml_config_shared(
            task_metadata_pb=task_metadata_pb
        )
        self._update_node_anchor_based_link_prediction_config_paths(
            pb=frozen_gbml_config_pb
        )
        self._populate_and_write_frozen_gbml_config(frozen_gbml_config_pb)

    def _mock_supervised_node_classification_assets(self):
        # Prepare GCS and BQ assets / environment.
        self._prepare_env()

        # Prepare frozen GbmlConfig.
        assert (
            self._mocked_dataset_info.sample_node_type is not None
        ), f"Need defined sample_node_type to mock for {TaskMetadataType.NODE_BASED_TASK} task."
        self._prepare_supervised_node_classification_frozen_gbml_config(
            sample_node_type=self._mocked_dataset_info.sample_node_type
        )

        # Upload assets to BQ
        mock_input_for_data_preprocessor.generate_bigquery_assets(
            mocked_dataset_info=self._mocked_dataset_info, version=self._version
        )

        # Mock SubgraphSampler inputs ("run Data Preprocessor")
        mock_input_for_subgraph_sampler.generate_preprocessed_tfrecord_data(
            mocked_dataset_info=self._mocked_dataset_info,
            version=self._version,
            gbml_config_pb=self._frozen_gbml_config_pb,
        )

        # Mock SplitGenerator inputs ("run Subgraph Sampler")
        hetero_data = mock_input_for_split_generator.build_and_write_supervised_node_classification_subgraph_samples_from_mocked_dataset_info(
            mocked_dataset_info=self._mocked_dataset_info,
            root_node_type=self._mocked_dataset_info.sample_node_type,
            gbml_config_pb=self._frozen_gbml_config_pb,
        )

        # Mock Trainer inputs ("run Split Generator")
        mock_input_for_trainer.split_and_write_supervised_node_classification_subgraph_samples_from_mocked_dataset_info(
            mocked_dataset_info=self._mocked_dataset_info,
            root_node_type=self._mocked_dataset_info.sample_node_type,
            gbml_config_pb=self._frozen_gbml_config_pb,
            hetero_data=hetero_data,
        )

        # Mock Inferencer inputs ("run Trainer")
        mock_input_for_inference.train_model(
            gbml_config_pb=self._frozen_gbml_config_pb,
        )

        # Mock Inferencer outputs ("run Inferencer")
        mock_output_for_inference.infer_model(
            gbml_config_pb=self._frozen_gbml_config_pb,
        )

    def _mock_node_anchor_based_link_prediction_assets(self):
        # Prepare GCS and BQ assets / environment.
        self._prepare_env()

        # Prepare frozen GbmlConfig.
        assert (
            self._mocked_dataset_info.sample_edge_type is not None
        ), f"Need defined sample_edge_type to mock for {TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK} task."
        self._prepare_node_anchor_based_link_prediction_frozen_gbml_config(
            sample_edge_type=self._mocked_dataset_info.sample_edge_type
        )

        # Upload assets to BQ
        mock_input_for_data_preprocessor.generate_bigquery_assets(
            mocked_dataset_info=self._mocked_dataset_info, version=self._version
        )

        # Mock SubgraphSampler inputs ("run Data Preprocessor")
        mock_input_for_subgraph_sampler.generate_preprocessed_tfrecord_data(
            mocked_dataset_info=self._mocked_dataset_info,
            version=self._version,
            gbml_config_pb=self._frozen_gbml_config_pb,
        )

        # Mock SplitGenerator inputs ("run Subgraph Sampler")
        hetero_data = mock_input_for_split_generator.build_and_write_node_anchor_link_prediction_subgraph_samples_from_mocked_dataset_info(
            mocked_dataset_info=self._mocked_dataset_info,
            sample_edge_type=self._mocked_dataset_info.sample_edge_type,
            gbml_config_pb=self._frozen_gbml_config_pb,
        )

        # Mock Trainer inputs ("run Split Generator")
        mock_input_for_trainer.split_and_write_node_anchor_link_prediction_subgraph_samples_from_mocked_dataset_info(
            mocked_dataset_info=self._mocked_dataset_info,
            sample_edge_type=self._mocked_dataset_info.sample_edge_type,
            gbml_config_pb=self._frozen_gbml_config_pb,
            hetero_data=hetero_data,
        )

        # Mock Inferencer inputs ("run Trainer")
        mock_input_for_inference.train_model(
            gbml_config_pb=self._frozen_gbml_config_pb,
        )

        # Mock Inferencer outputs ("run Inferencer")
        mock_output_for_inference.infer_model(
            gbml_config_pb=self._frozen_gbml_config_pb,
        )

    def _prepare_env(self):
        bq_utils = BqUtils()
        bq_utils.create_bq_dataset(
            dataset_id=mocking_constants.MOCK_DATA_BQ_DATASET_NAME, exists_ok=True
        )
        gcs_utils = GcsUtils()
        gcs_utils.delete_files_in_bucket_dir(
            gcs_path=mocking_constants.get_example_task_static_assets_gcs_dir(
                task_name=self._mocked_dataset_info.name, version=self._version
            )
        )

    def mock_assets(self, mocked_dataset_info: MockedDatasetInfo) -> GcsUri:
        self._mocked_dataset_info = mocked_dataset_info
        assert (
            mocked_dataset_info.version is not None
        ), "Need defined version to mock assets."
        self._version = mocked_dataset_info.version

        if mocked_dataset_info.task_metadata_type == TaskMetadataType.NODE_BASED_TASK:
            assert (
                mocked_dataset_info.sample_node_type is not None
            ), f"Need defined sample_node_type to mock for {TaskMetadataType.NODE_BASED_TASK} task."
            self._mock_supervised_node_classification_assets()
        elif (
            mocked_dataset_info.task_metadata_type
            == TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK
        ):
            assert (
                mocked_dataset_info.sample_edge_type is not None
            ), f"Need defined sample_edge_type to mock for {TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK} task."
            self._mock_node_anchor_based_link_prediction_assets()
        else:
            raise NotImplementedError

        frozen_gbml_config_uri = (
            mocking_constants.get_example_task_frozen_gbml_config_gcs_path(
                task_name=self._mocked_dataset_info.name, version=self._version
            )
        )
        return frozen_gbml_config_uri
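
The snippet below is a minimal usage sketch, not part of the module itself: it shows how the public entry point `mock_assets` might be driven to (re-)generate mocked assets for one dataset. The helper name `regenerate_mocked_assets` and the argument `my_mocked_dataset_info` are hypothetical; the `MockedDatasetInfo` instance is assumed to be constructed elsewhere with `version` set, plus `sample_node_type` or `sample_edge_type` depending on the task type.

# Illustrative usage sketch (assumed caller code, not part of this module).
from gigl.common import GcsUri
from gigl.src.mocking.dataset_asset_mocker import DatasetAssetMocker
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo


def regenerate_mocked_assets(my_mocked_dataset_info: MockedDatasetInfo) -> GcsUri:
    mocker = DatasetAssetMocker()
    # mock_assets prepares the BQ dataset and GCS directories, writes the frozen
    # GbmlConfig, and mocks each component's inputs/outputs in sequence; it
    # returns the GCS URI of the frozen GbmlConfig for the mocked task.
    frozen_config_uri = mocker.mock_assets(mocked_dataset_info=my_mocked_dataset_info)
    logger.info(f"Frozen GbmlConfig written to: {frozen_config_uri}")
    return frozen_config_uri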