Source code for gigl.src.inference.inferencer

import argparse
from typing import Optional

from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
    GiglResourceConfigWrapper,
)
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from gigl.src.inference.lib.assets import InferenceAssets
from gigl.src.inference.v1.gnn_inferencer import InferencerV1
from gigl.src.inference.v2.glt_inferencer import GLTInferencer

logger = Logger()
class Inferencer:
    """
    GiGL Component that runs static (GiGL) or dynamic (GLT) inference of a trained
    model on samples and outputs embedding and/or prediction assets.
    """
    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        custom_worker_image_uri: Optional[str] = None,
        cpu_docker_uri: Optional[str] = None,
        cuda_docker_uri: Optional[str] = None,
    ):
        gbml_config_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
            gbml_config_uri=task_config_uri
        )
        resource_config_wrapper: GiglResourceConfigWrapper = get_resource_config(
            resource_config_uri=resource_config_uri
        )

        # Prepare staging paths for inferencer assets by clearing the paths that the
        # inferencer will write to, so existing data is not clobbered.
        InferenceAssets.prepare_staging_paths(
            applied_task_identifier=applied_task_identifier,
            gbml_config_pb_wrapper=gbml_config_wrapper,
            project=resource_config_wrapper.project,
        )

        if gbml_config_wrapper.should_use_glt_backend:
            inferencer_glt = GLTInferencer()
            inferencer_glt.run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                resource_config_uri=resource_config_uri,
                cpu_docker_uri=cpu_docker_uri,
                cuda_docker_uri=cuda_docker_uri,
            )
        else:
            inferencer_v1 = InferencerV1(
                bq_gcp_project=resource_config_wrapper.project
            )
            inferencer_v1.run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                custom_worker_image_uri=custom_worker_image_uri,
            )
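# Illustrative programmatic usage (a minimal sketch; the job name and GCS paths
# below are placeholder values, not real assets):
#
#     inferencer = Inferencer()
#     inferencer.run(
#         applied_task_identifier=AppliedTaskIdentifier("my_inference_job"),
#         task_config_uri=UriFactory.create_uri("gs://my-bucket/task_config.yaml"),
#         resource_config_uri=UriFactory.create_uri("gs://my-bucket/resource_config.yaml"),
#     )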
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Program to run distributed inference"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
        required=True,
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="Gbml config uri",
        required=True,
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="Runtime argument for resource and env specifications of each component",
        required=True,
    )
    parser.add_argument(
        "--custom_worker_image_uri",
        type=str,
        help="Docker image to use for the worker harness in Dataflow",
        required=False,
    )
    parser.add_argument(
        "--cpu_docker_uri",
        type=str,
        help="User-specified or KFP-compiled Docker image for CPU inference",
        required=False,
    )
    parser.add_argument(
        "--cuda_docker_uri",
        type=str,
        help="User-specified or KFP-compiled Docker image for GPU inference",
        required=False,
    )
    args = parser.parse_args()

    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
    custom_worker_image_uri = args.custom_worker_image_uri
    cpu_docker_uri = args.cpu_docker_uri
    cuda_docker_uri = args.cuda_docker_uri

    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)

    applied_task_identifier = AppliedTaskIdentifier(args.job_name)
    inferencer = Inferencer()
    inferencer.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        resource_config_uri=resource_config_uri,
        custom_worker_image_uri=custom_worker_image_uri,
        cpu_docker_uri=cpu_docker_uri,
        cuda_docker_uri=cuda_docker_uri,
    )
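# Example command-line invocation (a sketch; the job name and GCS paths are
# placeholder values):
#
#     python -m gigl.src.inference.inferencer \
#         --job_name=my_inference_job \
#         --task_config_uri=gs://my-bucket/task_config.yaml \
#         --resource_config_uri=gs://my-bucket/resource_config.yaml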