Source code for gigl.src.inference.v1.gnn_inferencer

# NOTE: this class will be deprecated in the future without notice.
# Use python/gigl/src/inference/inferencer.py instead.

import argparse
import concurrent.futures
import sys
import threading
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult
from apache_beam.runners.runner import PipelineState
from google.cloud import bigquery

from gigl.common import GcsUri, Uri, UriFactory
from gigl.common.env_config import get_available_cpus
from gigl.common.logger import Logger
from gigl.common.metrics.decorators import flushes_metrics, profileit
from gigl.common.utils import os_utils
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.metrics import TIMER_INFERENCER_S
from gigl.src.common.graph_builder.graph_builder_factory import GraphBuilderFactory
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.graph_data import NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.bq import BqUtils
from gigl.src.common.utils.metrics_service_provider import (
    get_metrics_service_instance,
    initialize_metrics,
)
from gigl.src.common.utils.model import load_state_dict_from_uri
from gigl.src.inference.lib.assets import InferenceAssets
from gigl.src.inference.v1.lib.base_inference_blueprint import BaseInferenceBlueprint
from gigl.src.inference.v1.lib.base_inferencer import BaseInferencer
from gigl.src.inference.v1.lib.inference_blueprint_factory import (
    InferenceBlueprintFactory,
)
from gigl.src.inference.v1.lib.utils import (
    get_inferencer_pipeline_component_for_single_node_type,
)
from snapchat.research.gbml.inference_metadata_pb2 import InferenceOutput

logger = Logger()

MAX_INFERENCER_NUM_WORKERS = 4

@dataclass
class InferencerOutputPaths:
    """
    Dataclass containing the output path fields from running inference for a single node type.
    These fields are used to write files from GCS to BigQuery.
    """

    bq_inferencer_output_paths: InferenceOutput
    temp_predictions_gcs_path: Optional[GcsUri]
    temp_embedding_gcs_path: Optional[GcsUri]
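
# For illustration, how an instance of this dataclass might be populated after a
# per-node-type inference run (the URIs below are hypothetical placeholders):
#
#   InferencerOutputPaths(
#       bq_inferencer_output_paths=output_info_pb,  # InferenceOutput proto carrying the
#                                                   # predictions_path / embeddings_path BQ tables
#       temp_predictions_gcs_path=GcsUri("gs://my-bucket/temp/predictions/"),
#       temp_embedding_gcs_path=GcsUri("gs://my-bucket/temp/embeddings/"),
#   )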

class InferencerV1:
    """
    ********** WILL BE DEPRECATED **********
    Note: this class will be deprecated in the future without notice.
    Use python/gigl/src/inference/inferencer.py instead.
    ********** WILL BE DEPRECATED **********

    GiGL component that runs inference of a trained model on samples generated by
    the Subgraph Sampler component, and outputs embedding and/or prediction assets.
    """

    __bq_utils: BqUtils
    __gbml_config_pb_wrapper: GbmlConfigPbWrapper

    @property
    def bq_utils(self) -> BqUtils:
        if not self.__bq_utils:
            raise ValueError("bq_utils is not initialized before use.")
        return self.__bq_utils

    @property
    def gbml_config_pb_wrapper(self) -> GbmlConfigPbWrapper:
        if not self.__gbml_config_pb_wrapper:
            raise ValueError("gbml_config_pb_wrapper is not initialized before use.")
        return self.__gbml_config_pb_wrapper

    def write_from_gcs_to_bq(
        self,
        schema: Dict[str, List[Dict[str, str]]],
        gcs_uri: GcsUri,
        bq_table_uri: str,
    ) -> None:
        """
        Writes embeddings or predictions from a GCS folder to a BQ table with the specified schema.

        Args:
            schema (Dict[str, List[Dict[str, str]]]): BQ table schema for the embeddings or predictions inference output
            gcs_uri (GcsUri): GCS folder containing the embeddings or predictions inference output
            bq_table_uri (str): Path to the BQ table for the embeddings or predictions output
        """
        assert schema is not None
        assert "fields" in schema
        field_schema = schema["fields"]
        logger.info(f"schema = {field_schema}")
        logger.info(f"Loading from {gcs_uri} to BQ table: {bq_table_uri}")
        self.bq_utils.load_file_to_bq(
            source_path=GcsUri.join(gcs_uri, "*"),
            bq_path=bq_table_uri,
            job_config=bigquery.LoadJobConfig(
                source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
                write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
                schema=field_schema,
            ),
            retry=True,
        )
        logger.info(f"Finished loading to BQ table {bq_table_uri}")
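
    # For reference, the `schema` argument above is expected to be the JSON form of
    # a BigQuery table schema, e.g. (the field names here are illustrative only):
    #
    #   {
    #       "fields": [
    #           {"name": "node_id", "type": "INTEGER", "mode": "REQUIRED"},
    #           {"name": "emb", "type": "FLOAT", "mode": "REPEATED"},
    #       ]
    #   }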

    def generate_inferencer_instance(self) -> BaseInferencer:
        inferencer_class_path: str = (
            self.gbml_config_pb_wrapper.inferencer_config.inferencer_cls_path
        )
        kwargs: Dict[str, Any] = dict(
            self.gbml_config_pb_wrapper.inferencer_config.inferencer_args
        )
        inferencer_cls = os_utils.import_obj(inferencer_class_path)
        inferencer_instance: BaseInferencer
        try:
            inferencer_instance = inferencer_cls(**kwargs)
            assert isinstance(inferencer_instance, BaseInferencer)
        except Exception as e:
            logger.error(f"Could not instantiate class {inferencer_cls}: {e}")
            raise e
        model_save_path_uri = UriFactory.create_uri(
            self.gbml_config_pb_wrapper.shared_config.trained_model_metadata.trained_model_uri
        )
        logger.info(
            f"Loading model state dict from: {model_save_path_uri}, for inferencer: {inferencer_instance}"
        )
        model_state_dict = load_state_dict_from_uri(load_from_uri=model_save_path_uri)
        inferencer_instance.init_model(
            gbml_config_pb_wrapper=self.gbml_config_pb_wrapper,
            state_dict=model_state_dict,
        )
        return inferencer_instance
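
    # For illustration, the config fields this method reads might look like the
    # following in a task config (the class path, args, and key casing below are
    # hypothetical):
    #
    #   inferencerConfig:
    #     inferencerClsPath: my_pkg.inference.MyInferencer
    #     inferencerArgs:
    #       batch_size: "256"
    #
    # The class at `inferencer_cls_path` is imported via os_utils.import_obj and
    # instantiated with `inferencer_args` passed as keyword arguments.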

    def __infer_single_node_type(
        self,
        inference_blueprint: BaseInferenceBlueprint,
        applied_task_identifier: AppliedTaskIdentifier,
        custom_worker_image_uri: Optional[str],
        node_type: NodeType,
        uri_prefix_list: List[Uri],
        lock: threading.Lock,
    ) -> InferencerOutputPaths:
        """
        Runs inference on a single node type.

        Args:
            inference_blueprint (BaseInferenceBlueprint): Blueprint for running and saving inference for GBML pipelines
            applied_task_identifier (AppliedTaskIdentifier): Identifier for the GiGL job
            custom_worker_image_uri (Optional[str]): URI of a custom worker image
            node_type (NodeType): Node type being inferred
            uri_prefix_list (List[Uri]): List of URI prefixes for running inference on the given node type
            lock (threading.Lock): Lock to prevent race conditions when starting Dataflow pipelines

        Returns:
            (InferencerOutputPaths): Dataclass with path fields for writing from GCS to BigQuery for the given node type
        """
        node_type_to_inferencer_output_info_map = (
            self.gbml_config_pb_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map
        )
        # Sanity check that we have some paths defined for the intended inferred assets.
        if not (
            node_type_to_inferencer_output_info_map[node_type].embeddings_path
            or node_type_to_inferencer_output_info_map[node_type].predictions_path
        ):
            raise ValueError(
                f"Inference metadata for node type {node_type} is missing; must have at least one of "
                "embeddings_path or predictions_path defined."
            )
        should_persist_predictions = bool(
            node_type_to_inferencer_output_info_map[node_type].predictions_path
        )
        should_persist_embeddings = bool(
            node_type_to_inferencer_output_info_map[node_type].embeddings_path
        )
        temp_predictions_gcs_path: Optional[GcsUri]
        temp_embeddings_gcs_path: Optional[GcsUri]
        if should_persist_predictions:
            temp_predictions_gcs_path = InferenceAssets.get_gcs_asset_write_path_prefix(
                applied_task_identifier=applied_task_identifier,
                bq_table_path=node_type_to_inferencer_output_info_map[
                    node_type
                ].predictions_path,
            )
        else:
            temp_predictions_gcs_path = None
        if should_persist_embeddings:
            temp_embeddings_gcs_path = InferenceAssets.get_gcs_asset_write_path_prefix(
                applied_task_identifier=applied_task_identifier,
                bq_table_path=node_type_to_inferencer_output_info_map[
                    node_type
                ].embeddings_path,
            )
        else:
            temp_embeddings_gcs_path = None
        with lock:
            logger.debug(f"Node Type {node_type} acquiring lock.")
            p = get_inferencer_pipeline_component_for_single_node_type(
                gbml_config_pb_wrapper=self.gbml_config_pb_wrapper,
                inference_blueprint=inference_blueprint,
                applied_task_identifier=applied_task_identifier,
                custom_worker_image_uri=custom_worker_image_uri,
                node_type=node_type,
                uri_prefix_list=uri_prefix_list,
                temp_predictions_gcs_path=temp_predictions_gcs_path,
                temp_embeddings_gcs_path=temp_embeddings_gcs_path,
            )
            inferencer_pipeline_result = p.run()
            logger.debug(f"Node Type {node_type} releasing lock.")
        logger.info(f"Starting Dataflow job to run inference on {node_type} node type")
        inferencer_pipeline_result.wait_until_finish()
        logger.info(f"Finished Dataflow job to run inference on {node_type} node type")
        if isinstance(inferencer_pipeline_result, DataflowPipelineResult):
            pipeline_state: str = inferencer_pipeline_result.state
            if pipeline_state != PipelineState.DONE:
                raise RuntimeError(
                    f"A Dataflow pipeline failed, has state {pipeline_state}: {inferencer_pipeline_result}"
                )
        return InferencerOutputPaths(
            bq_inferencer_output_paths=node_type_to_inferencer_output_info_map[
                node_type
            ],
            temp_predictions_gcs_path=temp_predictions_gcs_path,
            temp_embedding_gcs_path=temp_embeddings_gcs_path,
        )

    def __run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        custom_worker_image_uri: Optional[str] = None,
    ):
        self.__gbml_config_pb_wrapper = (
            GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
                gbml_config_uri=task_config_uri
            )
        )
        if self.gbml_config_pb_wrapper.shared_config.should_skip_inference:
            logger.info("Skipping inference as flag set in GbmlConfig")
            return
        inferencer_instance: BaseInferencer = self.generate_inferencer_instance()
        graph_builder = GraphBuilderFactory.get_graph_builder(
            backend_name=inferencer_instance.model.graph_backend  # type: ignore
        )
        inference_blueprint: BaseInferenceBlueprint = (
            InferenceBlueprintFactory.get_inference_blueprint(
                gbml_config_pb_wrapper=self.gbml_config_pb_wrapper,
                inferencer_instance=inferencer_instance,
                graph_builder=graph_builder,
            )
        )
        node_type_to_inferencer_output_paths_map: Dict[
            NodeType, InferencerOutputPaths
        ] = dict()
        dataflow_setup_lock = threading.Lock()
        # We kick off multiple Inferencer pipelines, each of which kicks off a setup.py sdist run.
        # sdist has race-condition issues for simultaneous runs: https://github.com/pypa/setuptools/issues/1222
        # We have each thread take a lock when kicking off the pipelines to avoid this
        # issue; see the sketch after this method.
        num_workers = min(get_available_cpus(), MAX_INFERENCER_NUM_WORKERS)
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            logger.info(f"Using up to {num_workers} threads.")
            futures: Dict[
                concurrent.futures.Future[InferencerOutputPaths], NodeType
            ] = dict()
            for (
                node_type,
                uri_prefix_list,
            ) in (
                inference_blueprint.get_inference_data_tf_record_uri_prefixes().items()
            ):
                # Launch one Beam pipeline per node type.
                future = executor.submit(
                    self.__infer_single_node_type,
                    inference_blueprint=inference_blueprint,
                    applied_task_identifier=applied_task_identifier,
                    custom_worker_image_uri=custom_worker_image_uri,
                    node_type=node_type,
                    uri_prefix_list=uri_prefix_list,
                    lock=dataflow_setup_lock,
                )
                futures.update({future: node_type})
            for future in concurrent.futures.as_completed(futures):
                node_type = futures[future]
                try:
                    inferencer_output_paths: InferencerOutputPaths = future.result()
                    node_type_to_inferencer_output_paths_map[
                        node_type
                    ] = inferencer_output_paths
                except Exception as e:
                    logger.exception(
                        f"{node_type} inferencer job failed due to a raised exception: {e}"
                    )
                    raise e
        for (
            node_type,
            inferencer_output_paths,
        ) in node_type_to_inferencer_output_paths_map.items():
            condensed_node_type = self.__gbml_config_pb_wrapper.graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[
                node_type
            ]
            should_run_unenumeration = bool(
                self.__gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[
                    condensed_node_type
                ].enumerated_node_ids_bq_table
            )
            temp_predictions_gcs_path = (
                inferencer_output_paths.temp_predictions_gcs_path
            )
            temp_embeddings_gcs_path = inferencer_output_paths.temp_embedding_gcs_path
            bq_inferencer_output_paths = (
                inferencer_output_paths.bq_inferencer_output_paths
            )
            if temp_predictions_gcs_path is not None:
                self.write_from_gcs_to_bq(
                    schema=inference_blueprint.get_pred_table_schema(
                        should_run_unenumeration=should_run_unenumeration
                    ).schema,  # type: ignore
                    gcs_uri=temp_predictions_gcs_path,
                    bq_table_uri=bq_inferencer_output_paths.predictions_path,
                )
            if temp_embeddings_gcs_path is not None:
                self.write_from_gcs_to_bq(
                    schema=inference_blueprint.get_emb_table_schema(
                        should_run_unenumeration=should_run_unenumeration
                    ).schema,  # type: ignore
                    gcs_uri=temp_embeddings_gcs_path,
                    bq_table_uri=bq_inferencer_output_paths.embeddings_path,
                )
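
    # A minimal, self-contained sketch of the launch-under-lock pattern used in
    # __run above (all names in this sketch are illustrative, not part of this
    # module): each worker thread holds the lock only while building and starting
    # its pipeline (the step that triggers setup.py sdist), and releases it before
    # the blocking wait, so the pipelines themselves still execute concurrently.
    #
    #   def _launch_one(lock: threading.Lock) -> None:
    #       with lock:
    #           pipeline = build_pipeline()  # hypothetical helper
    #           result = pipeline.run()      # sdist happens here, under the lock
    #       result.wait_until_finish()       # safe to block outside the lock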

    @flushes_metrics(get_metrics_service_instance_fn=get_metrics_service_instance)
    @profileit(
        metric_name=TIMER_INFERENCER_S,
        get_metrics_service_instance_fn=get_metrics_service_instance,
    )
    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        custom_worker_image_uri: Optional[str] = None,
    ):
        try:
            return self.__run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                custom_worker_image_uri=custom_worker_image_uri,
            )
        except Exception as e:
            logger.error(
                "Inference failed due to a raised exception, which will follow."
            )
            logger.error(e)
            logger.error(traceback.format_exc())
            sys.exit(f"System will now exit: {e}")

    def __init__(self, bq_gcp_project: str):
        self.__bq_utils = BqUtils(project=bq_gcp_project if bq_gcp_project else None)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Program to run distributed inference"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
        required=True,
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="Gbml config uri",
        required=True,
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="Runtime argument for resource and env specifications of each component",
        required=True,
    )
    parser.add_argument(
        "--custom_worker_image_uri",
        type=str,
        help="Docker image to use for the worker harness in Dataflow",
        required=False,
    )
    parser.add_argument(
        "--cpu_docker_uri",
        type=str,
        help="User-specified or KFP-compiled Docker image for CPU inference",
        required=False,
    )
    parser.add_argument(
        "--cuda_docker_uri",
        type=str,
        help="User-specified or KFP-compiled Docker image for GPU inference",
        required=False,
    )
    args = parser.parse_args()

    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
    custom_worker_image_uri = args.custom_worker_image_uri

    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)

    applied_task_identifier = AppliedTaskIdentifier(args.job_name)
    inferencer = InferencerV1(bq_gcp_project=get_resource_config().project)
    inferencer.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        custom_worker_image_uri=custom_worker_image_uri,
    )
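
# Example invocation (all values below are hypothetical placeholders):
#
#   python -m gigl.src.inference.v1.gnn_inferencer \
#       --job_name my_inference_job \
#       --task_config_uri gs://my-bucket/configs/task_config.yaml \
#       --resource_config_uri gs://my-bucket/configs/resource_config.yaml \
#       --custom_worker_image_uri gcr.io/my-project/my-worker:latest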