Source code for gigl.src.training.v2.glt_trainer

import argparse
from typing import Optional

from google.cloud.aiplatform_v1.types import accelerator_type

from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService
from gigl.env.dep_constants import GIGL_SRC_IMAGE_CPU, GIGL_SRC_IMAGE_CUDA
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
    GiglResourceConfigWrapper,
)
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from snapchat.research.gbml.gigl_resource_config_pb2 import (
    LocalResourceConfig,
    VertexAiResourceConfig,
)

logger = Logger()

# TODO: (svij) We should parameterize this in the future
DEFAULT_VERTEX_AI_TIMEOUT_S = 60 * 60 * 3  # 3 hours

# TODO: (svij) This function may need some work cc @zfan3, @xgao4
# i.e. data loading may happen on GPU instead of training. Currently, there is
# no great support for GPU data loading, thus we assume training is done on GPU
# and data loading is done on CPU. This will need to be revisited.
def _determine_if_cpu_training(
    trainer_resource_config: VertexAiResourceConfig,
) -> bool:
    """Determine whether CPU training is required based on the trainer resource configuration."""
    if (
        not trainer_resource_config.gpu_type
        or trainer_resource_config.gpu_type
        == accelerator_type.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name  # type: ignore[attr-defined]  # `name` is defined
    ):
        return True
    else:
        return False
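
# Illustrative sketch (not part of the original module): how
# `_determine_if_cpu_training` classifies two hypothetical
# `VertexAiResourceConfig` values. The field names below appear in this file;
# the concrete machine and GPU strings are assumptions for illustration only.
#
#   cpu_config = VertexAiResourceConfig(machine_type="n1-standard-16")
#   _determine_if_cpu_training(trainer_resource_config=cpu_config)
#   # -> True: gpu_type unset (proto default is an empty string)
#
#   gpu_config = VertexAiResourceConfig(
#       machine_type="n1-standard-16",
#       gpu_type="NVIDIA_TESLA_T4",
#       gpu_limit=2,
#   )
#   _determine_if_cpu_training(trainer_resource_config=gpu_config)
#   # -> False: a GPU is requested, so CUDA training is used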

class GLTTrainer:
    """
    GiGL Component that runs GLT training using the training process command
    provided in the task config.
    """

    def __execute_VAI_training(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        cpu_docker_uri: Optional[str] = None,
        cuda_docker_uri: Optional[str] = None,
    ) -> None:
        resource_config: GiglResourceConfigWrapper = get_resource_config(
            resource_config_uri=resource_config_uri
        )
        gbml_config_pb_wrapper = (
            GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
                gbml_config_uri=task_config_uri
            )
        )
        training_process_command = gbml_config_pb_wrapper.trainer_config.command
        if not training_process_command:
            raise ValueError(
                "Currently, the GLT Trainer only supports a training process command,"
                + f" which was not provided in trainer config: {gbml_config_pb_wrapper.trainer_config}"
            )
        training_process_runtime_args = (
            gbml_config_pb_wrapper.trainer_config.trainer_args
        )

        assert isinstance(resource_config.trainer_config, VertexAiResourceConfig)
        trainer_resource_config: VertexAiResourceConfig = (
            resource_config.trainer_config
        )

        is_cpu_training = _determine_if_cpu_training(
            trainer_resource_config=trainer_resource_config
        )
        cpu_docker_uri = cpu_docker_uri or GIGL_SRC_IMAGE_CPU
        cuda_docker_uri = cuda_docker_uri or GIGL_SRC_IMAGE_CUDA
        container_uri = cpu_docker_uri if is_cpu_training else cuda_docker_uri

        job_args = (
            [
                f"--job_name={applied_task_identifier}",
                f"--task_config_uri={task_config_uri}",
                f"--resource_config_uri={resource_config_uri}",
            ]
            + ([] if is_cpu_training else ["--use_cuda"])
            + [f"--{k}={v}" for k, v in training_process_runtime_args.items()]
        )
        command = training_process_command.strip().split(" ")
        logger.info(f"Running trainer with command: {command}")

        vai_job_name = f"gigl_train_{applied_task_identifier}"
        job_config = VertexAiJobConfig(
            job_name=vai_job_name,
            container_uri=container_uri,
            command=command,
            args=job_args,
            # Suppress verbose TensorFlow C++ logging inside the container.
            environment_variables=[
                {"name": "TF_CPP_MIN_LOG_LEVEL", "value": "3"},
            ],
            machine_type=trainer_resource_config.machine_type,
            # Normalize e.g. "nvidia-tesla-t4" to Vertex AI's "NVIDIA_TESLA_T4".
            accelerator_type=trainer_resource_config.gpu_type.upper().replace(
                "-", "_"
            ),
            accelerator_count=trainer_resource_config.gpu_limit,
            replica_count=trainer_resource_config.num_replicas,
            labels=resource_config.get_resource_labels(
                component=GiGLComponents.Trainer
            ),
            timeout_s=trainer_resource_config.timeout
            if trainer_resource_config.timeout
            else None,
        )

        vertex_ai_service = VertexAIService(
            project=resource_config.project,
            location=resource_config.region,
            service_account=resource_config.service_account_email,
            staging_bucket=resource_config.temp_assets_regional_bucket_path.uri,
        )
        vertex_ai_service.launch_job(job_config=job_config)
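
    # Illustrative sketch (assumption, not part of the original module): given a
    # hypothetical GPU resource config and trainer_args {"num_epochs": "5"},
    # __execute_VAI_training above would assemble roughly these job args before
    # launching the Vertex AI job (the job name and URIs are made up):
    #
    #   job_args == [
    #       "--job_name=my_training_job",
    #       "--task_config_uri=gs://my-bucket/task_config.yaml",
    #       "--resource_config_uri=gs://my-bucket/resource_config.yaml",
    #       "--use_cuda",
    #       "--num_epochs=5",
    #   ]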

    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        cpu_docker_uri: Optional[str] = None,
        cuda_docker_uri: Optional[str] = None,
    ) -> None:
        # TODO: Support local training runs, i.e. non Vertex AI.
        resource_config_wrapper: GiglResourceConfigWrapper = get_resource_config(
            resource_config_uri=resource_config_uri
        )
        trainer_config = resource_config_wrapper.trainer_config
        if isinstance(trainer_config, LocalResourceConfig):
            # TODO: (svij) Implement local training
            raise NotImplementedError(
                "Local GLT training is not yet supported; please specify a"
                + f" {VertexAiResourceConfig.__name__} resource config field."
            )
        elif isinstance(trainer_config, VertexAiResourceConfig):
            self.__execute_VAI_training(
                applied_task_identifier=applied_task_identifier,
                task_config_uri=task_config_uri,
                resource_config_uri=resource_config_uri,
                cpu_docker_uri=cpu_docker_uri,
                cuda_docker_uri=cuda_docker_uri,
            )
        else:
            raise NotImplementedError(
                f"Unsupported resource config for GLT training: {type(trainer_config).__name__}"
            )
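
# Illustrative usage sketch (assumption, not part of the original module):
# invoking the trainer programmatically rather than through the CLI entry point
# below. The job name and URIs are hypothetical.
#
#   trainer = GLTTrainer()
#   trainer.run(
#       applied_task_identifier=AppliedTaskIdentifier("my_training_job"),
#       task_config_uri=UriFactory.create_uri("gs://my-bucket/task_config.yaml"),
#       resource_config_uri=UriFactory.create_uri("gs://my-bucket/resource_config.yaml"),
#   )
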
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Program to train a GBML model"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        help="Unique identifier for the job name",
        required=True,
    )
    parser.add_argument(
        "--task_config_uri",
        type=str,
        help="A URI pointing to a GbmlConfig proto serialized as YAML",
        required=True,
    )
    parser.add_argument(
        "--resource_config_uri",
        type=str,
        help="A URI pointing to a GiGLResourceConfig proto serialized as YAML",
        required=True,
    )
    parser.add_argument(
        "--cpu_docker_uri",
        type=str,
        help="User-specified or KFP-compiled Docker image for CPU training",
        required=False,
    )
    parser.add_argument(
        "--cuda_docker_uri",
        type=str,
        help="User-specified or KFP-compiled Docker image for GPU training",
        required=False,
    )
    args = parser.parse_args()

    applied_task_identifier = AppliedTaskIdentifier(args.job_name)
    task_config_uri = UriFactory.create_uri(args.task_config_uri)
    resource_config_uri = UriFactory.create_uri(args.resource_config_uri)
    cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri

    initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name)

    glt_trainer = GLTTrainer()
    glt_trainer.run(
        applied_task_identifier=applied_task_identifier,
        task_config_uri=task_config_uri,
        resource_config_uri=resource_config_uri,
        cpu_docker_uri=cpu_docker_uri,
        cuda_docker_uri=cuda_docker_uri,
    )
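
# Illustrative CLI invocation (assumption, not part of the original module); the
# job name and URIs are hypothetical:
#
#   python -m gigl.src.training.v2.glt_trainer \
#       --job_name=my_training_job \
#       --task_config_uri=gs://my-bucket/task_config.yaml \
#       --resource_config_uri=gs://my-bucket/resource_config.yaml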