Source code for gigl.src.common.utils.gigl_runtime

"""Runtime initialization helpers for GiGL component entrypoints."""

import os
from typing import Optional

from gigl.common import Uri
from gigl.env.constants import (
    GIGL_CPU_DOCKER_URI_ENV_KEY,
    GIGL_CUDA_DOCKER_URI_ENV_KEY,
)
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.utils.gigl_env import get_gigl_runtime_env_vars
from gigl.src.common.utils.metrics_service_provider import initialize_metrics


def initialize_gigl_runtime(
    applied_task_identifier: str,
    task_config_uri: Uri,
    resource_config_uri: Uri,
    service_name: str,
    component: GiGLComponents,
    cpu_docker_uri: Optional[str] = None,
    cuda_docker_uri: Optional[str] = None,
) -> None:
    """Set up the GiGL runtime environment variables and metrics for a component.

    ``SubgraphSampler`` and ``SplitGenerator`` are legacy (Scala/Spark)
    components that do not consume the GiGL Python runtime, so for them only
    metrics are initialized and ``os.environ`` is left untouched. For every
    other component the runtime env vars produced by
    ``get_gigl_runtime_env_vars`` are merged into ``os.environ`` first, and
    metrics are initialized afterwards.

    Args:
        applied_task_identifier: Unique identifier for the GiGL job.
        task_config_uri: URI to the task config YAML file.
        resource_config_uri: URI to the resource config YAML file.
        service_name: Name of the service, used for metric grouping.
        component: GiGL component being initialized.
        cpu_docker_uri: CPU source image URI. When ``None``, the current value
            of the ``GIGL_CPU_DOCKER_URI_ENV_KEY`` environment variable is
            used instead; if that is unset too, ``None`` is forwarded to
            ``get_gigl_runtime_env_vars`` (documented upstream as defaulting
            to the release CPU image).
        cuda_docker_uri: CUDA source image URI. Resolved the same way as
            ``cpu_docker_uri``, using ``GIGL_CUDA_DOCKER_URI_ENV_KEY``
            (documented upstream as defaulting to the release CUDA image).
    """
    metrics_only_components = {
        GiGLComponents.SubgraphSampler,
        GiGLComponents.SplitGenerator,
    }

    if component not in metrics_only_components:
        # TODO(kmonte): Also expose the dataflow docker URI (used as
        # custom_worker_image_uri by DataPreprocessor/Inferencer) as a
        # GIGL_DATAFLOW_DOCKER_URI env var for parity with the CPU/CUDA docker
        # URIs. Requires a new key in gigl/env/constants.py and threading it
        # through get_gigl_runtime_env_vars.

        # An explicitly passed URI wins; otherwise inherit whatever the
        # environment already carries (possibly nothing).
        cpu_image = cpu_docker_uri
        if cpu_image is None:
            cpu_image = os.environ.get(GIGL_CPU_DOCKER_URI_ENV_KEY)

        cuda_image = cuda_docker_uri
        if cuda_image is None:
            cuda_image = os.environ.get(GIGL_CUDA_DOCKER_URI_ENV_KEY)

        runtime_env_vars = get_gigl_runtime_env_vars(
            applied_task_identifier=applied_task_identifier,
            task_config_uri=task_config_uri,
            resource_config_uri=resource_config_uri,
            component=component,
            cpu_docker_uri=cpu_image,
            cuda_docker_uri=cuda_image,
        )
        os.environ.update(runtime_env_vars)

    # Metrics are initialized for every component, legacy or not, and always
    # after any environment mutation above.
    initialize_metrics(task_config_uri=task_config_uri, service_name=service_name)