import argparse
from typing import Optional
from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
    GiglResourceConfigWrapper,
)
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from gigl.src.inference.lib.assets import InferenceAssets
from gigl.src.inference.v1.gnn_inferencer import InferencerV1
from gigl.src.inference.v2.glt_inferencer import GLTInferencer
class Inferencer:
    """
    GiGL Component that runs static (GiGL) or dynamic (GLT) inference of a trained model on samples and outputs embedding and/or prediction assets.
    """

    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        custom_worker_image_uri: Optional[str] = None,
        cpu_docker_uri: Optional[str] = None,
        cuda_docker_uri: Optional[str] = None,
    ) -> None:
        """
        Run inference using the backend selected by the task config.

        The task config and resource config are loaded first. Then the staging
        paths for inference assets are prepared. Finally, work is dispatched to
        ``GLTInferencer`` when the task config's ``should_use_glt_backend`` flag
        is set, and to ``InferencerV1`` otherwise.

        Args:
            applied_task_identifier: Identifier of this job. It is passed to
                staging-path preparation and to the selected backend inferencer.
            task_config_uri: URI of the GbmlConfig (task config).
            resource_config_uri: URI of the GiGL resource config. Its
                ``project`` is used when preparing staging paths and as the
                ``bq_gcp_project`` of the v1 inferencer.
            custom_worker_image_uri: Docker image for the Dataflow worker
                harness. Only forwarded to the v1 (non-GLT) backend.
            cpu_docker_uri: Docker image for CPU inference. Only forwarded to
                the GLT backend.
            cuda_docker_uri: Docker image for GPU inference. Only forwarded to
                the GLT backend.
        """
        gbml_config_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
            gbml_config_uri=task_config_uri
        )
        resource_config_wrapper: GiglResourceConfigWrapper = get_resource_config(
            resource_config_uri=resource_config_uri
        )
        # Prepare staging paths for inferencer assets by clearing the paths that inferencer
        # would be writing to, to avoid clobbering of data.
        InferenceAssets.prepare_staging_paths(
            applied_task_identifier=applied_task_identifier,
            gbml_config_pb_wrapper=gbml_config_wrapper,
            project=resource_config_wrapper.project,
        )
        if gbml_config_wrapper.should_use_glt_backend:
            # GLT backend: takes the CPU/CUDA images; custom_worker_image_uri is not used.
            inferencer_glt = GLTInferencer()
            inferencer_glt.run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                resource_config_uri=resource_config_uri,
                cpu_docker_uri=cpu_docker_uri,
                cuda_docker_uri=cuda_docker_uri,
            )
        else:
            # v1 backend: takes the Dataflow worker image; CPU/CUDA images are not used.
            inferencer_v1 = InferencerV1(bq_gcp_project=resource_config_wrapper.project)
            inferencer_v1.run(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                custom_worker_image_uri=custom_worker_image_uri,
            )
 
if __name__ == "__main__":
    # Command-line entry point: parse arguments, initialize metrics, and run the Inferencer.
    parser = argparse.ArgumentParser(description="Program to run distributed inference")
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
        required=True,
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="Gbml config uri",
        required=True,
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="Runtime argument for resource and env specifications of each component",
        required=True,
    )
    parser.add_argument(
        "--custom_worker_image_uri",
        type=str,
        help="Docker image to use for the worker harness in dataflow",
        required=False,
    )
    parser.add_argument(
        "--cpu_docker_uri",
        type=str,
        help="User Specified or KFP compiled Docker Image for CPU inference",
        required=False,
    )
    parser.add_argument(
        "--cuda_docker_uri",
        type=str,
        help="User Specified or KFP compiled Docker Image for GPU inference",
        required=False,
    )
    args = parser.parse_args()

    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)

    # The job name serves both as the metrics service name and as the applied task identifier.
    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)
    applied_task_identifier = AppliedTaskIdentifier(args.job_name)

    inferencer = Inferencer()
    inferencer.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        resource_config_uri=resource_config_uri,
        custom_worker_image_uri=args.custom_worker_image_uri,
        cpu_docker_uri=args.cpu_docker_uri,
        cuda_docker_uri=args.cuda_docker_uri,
    )