# Source code for gigl.orchestration.local.runner

from collections import OrderedDict
from dataclasses import dataclass
from typing import Callable, Optional

from gigl.common import Uri
from gigl.common.logger import Logger
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.env.dep_constants import GIGL_DATAFLOW_IMAGE
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from gigl.src.config_populator.config_populator import ConfigPopulator
from gigl.src.data_preprocessor.data_preprocessor import DataPreprocessor
from gigl.src.inference.inferencer import Inferencer
from gigl.src.split_generator.split_generator import SplitGenerator
from gigl.src.subgraph_sampler.subgraph_sampler import SubgraphSampler
from gigl.src.training.trainer import Trainer
from gigl.src.validation_check.config_validator import (
    START_COMPONENT_TO_ASSET_CHECKS_MAP,
    START_COMPONENT_TO_CLS_CHECKS_MAP,
)
from snapchat.research.gbml import gbml_config_pb2

# Module-level logger used by every Runner step below.
logger = Logger()
@dataclass
class PipelineConfig:
    """
    Bundle of settings handed to every step of the locally orchestrated GiGL pipeline.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): Your job name.
        task_config_uri (Uri): URI to your template task config.
        resource_config_uri (Uri): URI to your resource config.
        custom_cuda_docker_uri (Optional[str]): For custom training spec and GPU
            training on VertexAI. Forwarded to the trainer and inferencer.
        custom_cpu_docker_uri (Optional[str]): For custom training spec and CPU
            training on VertexAI. Forwarded to the trainer and inferencer.
        dataflow_docker_uri (Optional[str]): For custom datapreprocessor spec that
            will run in dataflow. Defaults to ``GIGL_DATAFLOW_IMAGE``.
    """

    # Required fields (no defaults) must stay first and in this order, because
    # the generated ``__init__`` accepts them positionally.
    applied_task_identifier: AppliedTaskIdentifier
    task_config_uri: Uri
    resource_config_uri: Uri

    # Optional docker image overrides.
    custom_cuda_docker_uri: Optional[str] = None
    custom_cpu_docker_uri: Optional[str] = None
    dataflow_docker_uri: Optional[str] = GIGL_DATAFLOW_IMAGE
class Runner:
    """
    Orchestration of GiGL Pipeline with local execution.

    All methods are static; the class only groups the per-component entry
    points. ``Runner.run`` chains them in pipeline order, while each
    ``run_<component>`` method can also be called on its own.

    Args:
        pipeline_config (PipelineConfig): Configuration for the pipeline.
        start_at (str): Component to start the pipeline from. Default is config_populator.
    """

    @staticmethod
    def run(
        pipeline_config: PipelineConfig,
        start_at: str = GiGLComponents.ConfigPopulator.value,
    ):
        """
        Runs the GiGL pipeline locally starting from the specified component.

        Components are executed in this order: config populator, data
        preprocessor, subgraph sampler, split generator, trainer, inferencer.
        Everything from ``start_at`` onwards is run; earlier components are
        skipped.

        When the pipeline starts at the config populator, the URI it returns
        replaces ``pipeline_config.task_config_uri`` *in place*, so later
        components (and the caller) see the frozen config. When it starts at
        any later component, ``Runner.config_check`` is run against the given
        task config first.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
            start_at (str): Component to start the pipeline from.
                Defaults to 'config_populator'.

        Returns:
            None

        Raises:
            ValueError: If ``start_at`` does not match any runnable component.
        """
        config_populator_name = GiGLComponents.ConfigPopulator.value

        # Keys are the components' string values (not enum members); insertion
        # order defines execution order.
        component_map: OrderedDict[str, Callable] = OrderedDict(
            [
                (config_populator_name, Runner.run_config_populator),
                (GiGLComponents.DataPreprocessor.value, Runner.run_data_preprocessor),
                (GiGLComponents.SubgraphSampler.value, Runner.run_subgraph_sampler),
                (GiGLComponents.SplitGenerator.value, Runner.run_split_generator),
                (GiGLComponents.Trainer.value, Runner.run_trainer),
                (GiGLComponents.Inferencer.value, Runner.run_inferencer),
            ]
        )

        # Fail fast: an unknown start_at would otherwise match nothing in the
        # loop below and the pipeline would silently do no work. Equality (not
        # dict membership) is used so matching is identical to the loop's.
        if not any(component == start_at for component in component_map):
            raise ValueError(
                f"Unknown start_at component {start_at!r}; "
                f"expected one of {list(component_map)}"
            )

        logger.info(
            f"Running pipeline from component {start_at} with parameters: \n"
            f"job_name: {pipeline_config.applied_task_identifier}\n"
            f"task_config_uri: {pipeline_config.task_config_uri}\n"
            f"resource_config_uri: {pipeline_config.resource_config_uri}\n"
            f"dataflow_docker_uri: {pipeline_config.dataflow_docker_uri}"
        )
        initialize_metrics(
            task_config_uri=pipeline_config.task_config_uri,
            service_name=pipeline_config.applied_task_identifier,
        )

        if start_at != config_populator_name:
            # Config populator is skipped, so check the supplied task config
            # before running anything.
            Runner.config_check(start_at, pipeline_config)

        started: bool = False
        for component, method in component_map.items():
            if component == start_at:
                started = True
            if not started:
                continue
            if component == config_populator_name:
                # Run the populator exactly once and keep its result. It was
                # previously invoked here a second time, against the frozen
                # config, with the return value discarded.
                pipeline_config.task_config_uri = method(pipeline_config)
            else:
                method(pipeline_config)

    @staticmethod
    def config_check(start_at: str, pipeline_config: PipelineConfig):
        """
        Runs the checks registered for ``start_at`` against the task config.

        The task config at ``pipeline_config.task_config_uri`` is read as a
        ``GbmlConfig`` proto, then every check listed for ``start_at`` in
        ``START_COMPONENT_TO_CLS_CHECKS_MAP`` is called, followed by every
        check listed in ``START_COMPONENT_TO_ASSET_CHECKS_MAP``. A component
        with no entry in a map simply has no checks of that kind.

        Args:
            start_at (str): Component the pipeline is about to start from.
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        proto_utils = ProtoUtils()
        gbml_config_pb: gbml_config_pb2.GbmlConfig = proto_utils.read_proto_from_yaml(
            uri=pipeline_config.task_config_uri, proto_cls=gbml_config_pb2.GbmlConfig
        )
        for cls_check in START_COMPONENT_TO_CLS_CHECKS_MAP.get(start_at, []):
            cls_check(gbml_config_pb=gbml_config_pb)
        for asset_check in START_COMPONENT_TO_ASSET_CHECKS_MAP.get(start_at, []):
            asset_check(gbml_config_pb=gbml_config_pb)

    @staticmethod
    def run_config_populator(pipeline_config: PipelineConfig) -> Uri:
        """
        Runs the Config Populator on the task config.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.

        Returns:
            Uri: The URI returned by ``ConfigPopulator.run`` (``Runner.run``
            uses it as the frozen task config for downstream components).
        """
        logger.info("Running Config Populator...")
        config_populator = ConfigPopulator()
        return config_populator.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
        )

    @staticmethod
    def run_data_preprocessor(pipeline_config: PipelineConfig) -> None:
        """
        Runs the Data Preprocessor.

        ``pipeline_config.dataflow_docker_uri`` is passed as the custom worker
        image.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Data Preprocessor...")
        data_preprocessor = DataPreprocessor()
        data_preprocessor.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
            custom_worker_image_uri=pipeline_config.dataflow_docker_uri,
        )

    @staticmethod
    def run_subgraph_sampler(pipeline_config: PipelineConfig) -> None:
        """
        Runs the Subgraph Sampler.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Subgraph Sampler...")
        subgraph_sampler = SubgraphSampler()
        subgraph_sampler.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
        )

    @staticmethod
    def run_split_generator(pipeline_config: PipelineConfig) -> None:
        """
        Runs the Split Generator.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Split Generator...")
        split_generator = SplitGenerator()
        split_generator.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
        )

    @staticmethod
    def run_trainer(pipeline_config: PipelineConfig) -> None:
        """
        Runs the Trainer.

        The custom CPU and CUDA docker URIs from ``pipeline_config`` are
        forwarded as-is (``None`` when not set).

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Trainer...")
        trainer = Trainer()
        trainer.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
            cpu_docker_uri=pipeline_config.custom_cpu_docker_uri,
            cuda_docker_uri=pipeline_config.custom_cuda_docker_uri,
        )

    @staticmethod
    def run_inferencer(pipeline_config: PipelineConfig) -> None:
        """
        Runs the Inferencer.

        ``pipeline_config.dataflow_docker_uri`` is passed as the custom worker
        image, and the custom CPU and CUDA docker URIs are forwarded as-is
        (``None`` when not set).

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Inferencer...")
        inferencer = Inferencer()
        inferencer.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
            custom_worker_image_uri=pipeline_config.dataflow_docker_uri,
            cpu_docker_uri=pipeline_config.custom_cpu_docker_uri,
            cuda_docker_uri=pipeline_config.custom_cuda_docker_uri,
        )