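"""Source code for gigl.src.post_process.utils.unenumeration.

Utilities that un-enumerate inference outputs (embeddings and/or predictions),
i.e. map the enumerated node ids in those BigQuery tables back to the original
node ids using the Data Preprocessor's enumeration mapping tables.
"""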
import concurrent
import concurrent.futures
from typing import List

from google.cloud import bigquery

import gigl.src.data_preprocessor.lib.enumerate.queries as enumeration_queries
import gigl.src.inference.v1.lib.queries as inference_queries
from gigl.common.env_config import get_available_cpus
from gigl.common.logger import Logger
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types.graph_data import NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.bq import BqUtils
from gigl.src.inference.lib.assets import InferenceAssets
from gigl.src.inference.v1.lib.inference_output_schema import (
    DEFAULT_EMBEDDINGS_TABLE_SCHEMA,
    DEFAULT_PREDICTIONS_TABLE_SCHEMA,
)
from snapchat.research.gbml import preprocessed_metadata_pb2
def _unenumerate_single_inferred_asset(
    inference_output_enumerated_assets_table: str,
    inference_output_node_id_field: str,
    inference_output_unenumerated_assets_table: str,
    enumerator_mapping_table: str,
):
    """Map one enumerated inference output table back to original node ids.

    Formats ``inference_queries.UNENUMERATION_QUERY`` with the given tables and
    field names, then runs it in BigQuery. The destination table is overwritten
    (``WRITE_TRUNCATE``) with the query result. The job is labelled with the
    resource labels for the Inferencer component.

    Args:
        inference_output_enumerated_assets_table (str): BQ table holding the
            prediction or embedding assets keyed by enumerated node id.
        inference_output_node_id_field (str): Column of that table which holds
            the enumerated node ids.
        inference_output_unenumerated_assets_table (str): Destination BQ table
            for the "final" un-enumerated assets.
        enumerator_mapping_table (str): BQ table mapping enumerated ids to
            original node ids.
    """
    # TODO: relevant resource config args should be passed through instead of using global config
    config = get_resource_config()
    client = BqUtils(project=config.project)

    # The id column names in the mapping table come from the enumerator's defaults.
    unenumeration_sql = inference_queries.UNENUMERATION_QUERY.format(
        enumerated_assets_table=inference_output_enumerated_assets_table,
        mapping_table=enumerator_mapping_table,
        node_id_field=inference_output_node_id_field,
        original_node_id_field=enumeration_queries.DEFAULT_ORIGINAL_NODE_ID_FIELD,
        enumerated_int_id_field=enumeration_queries.DEFAULT_ENUMERATED_NODE_ID_FIELD,
    )
    job_labels = config.get_resource_labels(component=GiGLComponents.Inferencer)

    client.run_query(
        query=unenumeration_sql,
        labels=job_labels,
        destination=inference_output_unenumerated_assets_table,
        write_disposition=bigquery.job.WriteDisposition.WRITE_TRUNCATE,
    )
def unenumerate_all_inferred_bq_assets(
    gbml_config_pb_wrapper: GbmlConfigPbWrapper,
) -> None:
    """Un-enumerates assets that are produced by inference.

    These assets include embeddings and/or predictions. The node ids in these
    outputs are enumerated according to logic specified in the Data Preprocessor
    component.

    For each node type in the inferencer output map, an un-enumeration job is
    scheduled for its embeddings table and/or its predictions table. A job is
    scheduled only when:
      * the node type's preprocessed metadata has an enumerated-node-id mapping
        table (otherwise the node type is skipped), and
      * a non-empty un-enumerated output path is configured for that asset.

    Each job runs ``_unenumerate_single_inferred_asset``. All jobs are executed
    concurrently on a thread pool sized by ``get_available_cpus()``.

    Args:
        gbml_config_pb_wrapper (GbmlConfigPbWrapper): Wrapped GbmlConfig. It
            supplies the inferencer output info, the graph metadata (node type ->
            condensed node type), the preprocessed metadata (mapping tables) and
            the inference asset table paths.

    Raises:
        Exception: The first exception raised by an un-enumeration job is
            re-raised. It propagates after all submitted jobs have finished,
            because the executor's context manager waits on shutdown.
    """
    # This file defines no module-level `logger` (only `Logger` is imported), so
    # create one here; otherwise every log call below would raise NameError.
    logger = Logger()

    # First we need to read all the node types in inferencer output and get
    # their condensed node types.
    inference_output_map = (
        gbml_config_pb_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map
    )
    node_type_to_condensed_node_type_map = (
        gbml_config_pb_wrapper.graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map
    )
    # The preprocessed metadata does not depend on the node type, so look it up once.
    preprocessed_metadata_pb: preprocessed_metadata_pb2.PreprocessedMetadata = (
        gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.preprocessed_metadata
    )
    node_type_metadata_map = (
        preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata
    )

    # Collect every asset that needs un-enumerating. Each job is the tuple
    # (enumerated table, enumerated node id field, un-enumerated table, mapping table).
    # Keeping them together avoids four parallel lists that must stay in sync.
    jobs: list[tuple[str, str, str, str]] = []
    for node_type in inference_output_map:
        # Get the condensed node type for the inference node type.
        condensed_inference_node_type = node_type_to_condensed_node_type_map[
            NodeType(node_type)
        ]
        logger.info(
            f"Processing node type: {node_type} with condensed node type: {condensed_inference_node_type}"
        )
        node_metadata_output = node_type_metadata_map[
            int(condensed_inference_node_type)
        ]
        # schema; node_id, int_id (opinionated, specified by enumerator queries)
        mapping_bq_table = node_metadata_output.enumerated_node_ids_bq_table
        if not mapping_bq_table:
            logger.info(
                f"Skipping un-enumeration for node_type={node_type} since no mapping table exists"
            )
            continue
        logger.info(f"Found mapping table to be: {mapping_bq_table}")

        unenumerated_embedding_table_path: str = (
            InferenceAssets.get_unenumerated_embedding_table_path(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper, node_type=node_type
            )
        )
        if unenumerated_embedding_table_path:
            jobs.append(
                (
                    InferenceAssets.get_enumerated_embedding_table_path(
                        gbml_config_pb_wrapper=gbml_config_pb_wrapper,
                        node_type=node_type,
                    ),
                    DEFAULT_EMBEDDINGS_TABLE_SCHEMA.node_field,
                    unenumerated_embedding_table_path,
                    mapping_bq_table,
                )
            )

        unenumerated_prediction_table_path: str = (
            InferenceAssets.get_unenumerated_prediction_table_path(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper, node_type=node_type
            )
        )
        if unenumerated_prediction_table_path:
            jobs.append(
                (
                    InferenceAssets.get_enumerated_predictions_table_path(
                        gbml_config_pb_wrapper=gbml_config_pb_wrapper,
                        node_type=node_type,
                    ),
                    DEFAULT_PREDICTIONS_TABLE_SCHEMA.node_field,
                    unenumerated_prediction_table_path,
                    mapping_bq_table,
                )
            )

    unenumerated_assets_output_tables = [
        unenumerated_table for _, _, unenumerated_table, _ in jobs
    ]

    # Finally, we un-enumerate all the enumerated assets in parallel.
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=get_available_cpus()
    ) as executor:
        futures: list[concurrent.futures.Future] = [
            executor.submit(
                _unenumerate_single_inferred_asset,
                inference_output_enumerated_assets_table=enumerated_table,
                inference_output_node_id_field=node_id_field,
                inference_output_unenumerated_assets_table=unenumerated_table,
                enumerator_mapping_table=mapping_table,
            )
            for enumerated_table, node_id_field, unenumerated_table, mapping_table in jobs
        ]
        for future in concurrent.futures.as_completed(futures):
            future.result()  # Re-raise any exception from the worker thread.
    logger.info(f"Output to tables: {', '.join(unenumerated_assets_output_tables)}")