Source code for gigl.src.inference.lib.assets

from typing import List

from gigl.common import GcsUri
from gigl.common.logger import Logger
from gigl.common.utils.gcs import GcsUtils
from gigl.src.common.constants import gcs as gcs_constants
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.bq import BqUtils

# Module-level logger used by the InferenceAssets staging/cleanup helpers below.
logger = Logger()
class InferenceAssets:
    """
    Utility class for managing temp and permanent inferencer assets.

    All members are static; the class is used purely as a namespace. Paths
    are derived from the per-node-type inferencer output info stored under
    ``shared_config.inference_metadata`` of the provided GbmlConfig wrapper.
    """

    @staticmethod
    def get_unenumerated_embedding_table_path(
        gbml_config_pb_wrapper: GbmlConfigPbWrapper, node_type: str
    ) -> str:
        """
        Get the unenumerated embedding table path for a given node type.
        i.e. table contains the embeddings indexed by original node id

        Args:
            gbml_config_pb_wrapper (GbmlConfigPbWrapper): Config wrapper holding the inference metadata
            node_type (str): Key into ``node_type_to_inferencer_output_info_map``
        Returns:
            str: The ``embeddings_path`` recorded for ``node_type``. Other methods in this
                class treat a falsy value as "no embeddings output configured".
        """
        return gbml_config_pb_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map[
            node_type
        ].embeddings_path

    @staticmethod
    def get_unenumerated_prediction_table_path(
        gbml_config_pb_wrapper: GbmlConfigPbWrapper, node_type: str
    ) -> str:
        """
        Get the unenumerated prediction table path for a given node type.
        i.e. table contains the predictions indexed by original node id

        Args:
            gbml_config_pb_wrapper (GbmlConfigPbWrapper): Config wrapper holding the inference metadata
            node_type (str): Key into ``node_type_to_inferencer_output_info_map``
        Returns:
            str: The ``predictions_path`` recorded for ``node_type``. Other methods in this
                class treat a falsy value as "no predictions output configured".
        """
        return gbml_config_pb_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map[
            node_type
        ].predictions_path

    @staticmethod
    def get_enumerated_embedding_table_path(
        gbml_config_pb_wrapper: GbmlConfigPbWrapper, node_type: str
    ) -> str:
        """
        Get the enumerated embedding table path for a given node type.
        i.e. table should contain (enumerated_node_id: int) ---> embedding

        Args:
            gbml_config_pb_wrapper (GbmlConfigPbWrapper): Config wrapper holding the inference metadata
            node_type (str): Node type whose embedding table path is requested
        Returns:
            str: The enumerated table path, or "" when no unenumerated embedding
                path is set for ``node_type``.
        """
        unenumerated_bq_table_path = (
            InferenceAssets.get_unenumerated_embedding_table_path(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper, node_type=node_type
            )
        )
        # This is optional and may be none; we traditionally don't have any asserts here so matching that style
        # This should probably change in the future
        if not unenumerated_bq_table_path:
            return ""
        return InferenceAssets._create_enumerated_bq_table_name(
            unenumerated_bq_table_path=unenumerated_bq_table_path
        )

    @staticmethod
    def get_enumerated_predictions_table_path(
        gbml_config_pb_wrapper: GbmlConfigPbWrapper, node_type: str
    ) -> str:
        """
        Get the enumerated predictions table path for a given node type.
        i.e. table should contain (enumerated_node_id: int) ---> prediction

        Args:
            gbml_config_pb_wrapper (GbmlConfigPbWrapper): Config wrapper holding the inference metadata
            node_type (str): Node type whose predictions table path is requested
        Returns:
            str: The enumerated table path, or "" when no unenumerated predictions
                path is set for ``node_type``.
        """
        unenumerated_bq_table_path = (
            InferenceAssets.get_unenumerated_prediction_table_path(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper, node_type=node_type
            )
        )
        # This is optional and may be none; we traditionally don't have any asserts here so matching that style
        # This should probably change in the future
        if not unenumerated_bq_table_path:
            return ""
        return InferenceAssets._create_enumerated_bq_table_name(
            unenumerated_bq_table_path=unenumerated_bq_table_path
        )

    @staticmethod
    def prepare_staging_paths(
        applied_task_identifier: AppliedTaskIdentifier,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        project: str,
    ) -> None:
        """
        Prepare staging paths for inferencer assets by clearing the paths that inferencer
        would be writing to, to avoid clobbering of data.

        Args:
            applied_task_identifier (AppliedTaskIdentifier): The name provided for the gigl job
            gbml_config_pb_wrapper (GbmlConfigPbWrapper): Config wrapper holding the inference metadata
            project (str): Project passed to the GCS and BigQuery utility clients
        """
        logger.info("Preparing staging paths for Inferencer...")
        InferenceAssets._delete_temp_gcs_files(
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
            applied_task_identifier=applied_task_identifier,
            project=project,
        )
        InferenceAssets._delete_bq_output_tables(
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
            project=project,
        )
        logger.info("Staging paths for Inferencer prepared.")

    @staticmethod
    def get_gcs_asset_write_path_prefix(
        applied_task_identifier: AppliedTaskIdentifier, bq_table_path: str
    ) -> GcsUri:
        """
        Formulate an intermediary GCS path for writing embeddings or predictions based on the bq table path

        Args:
            applied_task_identifier (AppliedTaskIdentifier): The name provided for the gigl job
            bq_table_path (str): Path to the table for embeddings or predictions output
        Returns:
            GcsUri: The path to the gcs folder based on the bq table path
        """
        # Turn the table path into a single folder name: "." -> "_" and ":" -> "__".
        formatted_gcs_path = bq_table_path.replace(".", "_").replace(":", "__")
        # TODO (mkolodner): Update code to write to gcs in permanent storage location for enabling gcs inferencer output
        # TODO (svij): Ideally this should be writing to gcs paths formulated by:
        # gigl.src.common.constants.gcs._INFERENCER_PREFIX
        return GcsUri.join(
            gcs_constants.get_applied_task_temp_gcs_path(
                applied_task_identifier=applied_task_identifier
            ),
            f"{formatted_gcs_path}/",
        )

    @staticmethod
    def _get_active_bq_table_paths(
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
    ) -> List[str]:
        """
        Collect every configured inferencer output table path.

        For each node type in ``node_type_to_inferencer_output_info_map`` the predictions
        path and then the embeddings path are considered; each non-empty path contributes
        the unenumerated path followed by its enumerated counterpart.

        Args:
            gbml_config_pb_wrapper (GbmlConfigPbWrapper): Config wrapper holding the inference metadata
        Returns:
            List[str]: Unenumerated and enumerated table paths, in the order described above
        """
        active_bq_table_paths: List[str] = []
        node_type_to_output_info = (
            gbml_config_pb_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map
        )
        for node_type in node_type_to_output_info:
            # Predictions first, then embeddings, so the resulting order is stable.
            unenumerated_paths = (
                InferenceAssets.get_unenumerated_prediction_table_path(
                    gbml_config_pb_wrapper=gbml_config_pb_wrapper, node_type=node_type
                ),
                InferenceAssets.get_unenumerated_embedding_table_path(
                    gbml_config_pb_wrapper=gbml_config_pb_wrapper, node_type=node_type
                ),
            )
            for unenumerated_path in unenumerated_paths:
                if not unenumerated_path:
                    # Output not configured for this node type; nothing to track.
                    continue
                active_bq_table_paths.append(unenumerated_path)
                active_bq_table_paths.append(
                    InferenceAssets._create_enumerated_bq_table_name(
                        unenumerated_bq_table_path=unenumerated_path
                    )
                )
        return active_bq_table_paths

    @staticmethod
    def _delete_temp_gcs_files(
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        applied_task_identifier: AppliedTaskIdentifier,
        project: str,
    ) -> None:
        """
        Delete temporary GCS files created by the inferencer.

        Args:
            gbml_config_pb_wrapper (GbmlConfigPbWrapper): Config wrapper holding the inference metadata
            applied_task_identifier (AppliedTaskIdentifier): The name provided for the gigl job
            project (str): Project passed to the GCS utility client
        """
        logger.info("Deleting temporary GCS files...")
        gcs_utils = GcsUtils(project=project)
        for bq_table_path in InferenceAssets._get_active_bq_table_paths(
            gbml_config_pb_wrapper=gbml_config_pb_wrapper
        ):
            table_gcs_write_path_uri: GcsUri = (
                InferenceAssets.get_gcs_asset_write_path_prefix(
                    applied_task_identifier=applied_task_identifier,
                    bq_table_path=bq_table_path,
                )
            )
            gcs_utils.delete_files_in_bucket_dir(table_gcs_write_path_uri)

    @staticmethod
    def _delete_bq_output_tables(
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
        project: str,
    ) -> None:
        """
        Delete the BigQuery output tables (unenumerated and enumerated) the inferencer writes to.

        Args:
            gbml_config_pb_wrapper (GbmlConfigPbWrapper): Config wrapper holding the inference metadata
            project (str): Project passed to the BigQuery utility client
        """
        logger.info("Deleting BigQuery output tables...")
        bq_utils = BqUtils(project=project)
        for bq_table_path in InferenceAssets._get_active_bq_table_paths(
            gbml_config_pb_wrapper=gbml_config_pb_wrapper
        ):
            bq_utils.delete_bq_table_if_exist(bq_table_path=bq_table_path)

    @staticmethod
    def _create_enumerated_bq_table_name(unenumerated_bq_table_path: str) -> str:
        """
        Derive the enumerated table path from an unenumerated table path by
        prefixing the table id with ``enumerated_``.

        Args:
            unenumerated_bq_table_path (str): Path to the unenumerated table.
                Format should be project-id.dataset-id.table-id
        Returns:
            str: project-id.dataset-id.enumerated_table-id
        Raises:
            AssertionError: If the path does not split into exactly three "."-separated parts
        """
        bq_table_path_list: List[str] = unenumerated_bq_table_path.split(".")
        assert len(bq_table_path_list) == 3, (
            f"Invalid bq_table_path: {unenumerated_bq_table_path}, "
            f"expected format: project-id.dataset-id.table-id; "
            f"got {len(bq_table_path_list)} '.'-separated part(s)"
        )
        project_id, dataset_id, table_id = bq_table_path_list
        return f"{project_id}.{dataset_id}.enumerated_{table_id}"