# Source code for gigl.orchestration.local.runner
from collections import OrderedDict
from dataclasses import dataclass
from typing import Callable, Optional
from gigl.common import Uri
from gigl.common.constants import DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU
from gigl.common.logger import Logger
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from gigl.src.config_populator.config_populator import ConfigPopulator
from gigl.src.data_preprocessor.data_preprocessor import DataPreprocessor
from gigl.src.inference.inferencer import Inferencer
from gigl.src.split_generator.split_generator import SplitGenerator
from gigl.src.subgraph_sampler.subgraph_sampler import SubgraphSampler
from gigl.src.training.trainer import Trainer
from gigl.src.validation_check.config_validator import (
    START_COMPONENT_TO_ASSET_CHECKS_MAP,
    START_COMPONENT_TO_CLS_CHECKS_MAP,
)
from snapchat.research.gbml import gbml_config_pb2
@dataclass
class PipelineConfig:
    """
    Configuration for the GiGL pipeline.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): your job name
        task_config_uri (Uri): URI to your template task config
        resource_config_uri (Uri): URI to your resource config
        custom_cuda_docker_uri (Optional[str]): For custom training spec and GPU training on VertexAI
        custom_cpu_docker_uri (Optional[str]): For custom training spec and CPU training on VertexAI
        dataflow_docker_uri (Optional[str]): Worker image for Dataflow jobs; passed to both the
            data preprocessor and the inferencer. Defaults to
            DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU.
    """

    applied_task_identifier: AppliedTaskIdentifier
    # Read by every pipeline component. Runner.run overwrites it in place with the
    # frozen config URI when the pipeline starts from the config populator, which
    # is why this dataclass is left mutable (not frozen).
    task_config_uri: Uri
    resource_config_uri: Uri
    custom_cuda_docker_uri: Optional[str] = None
    custom_cpu_docker_uri: Optional[str] = None
    dataflow_docker_uri: Optional[str] = DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU
 
logger = Logger()


class Runner:
    """
    Orchestration of GiGL Pipeline with local execution.

    All entry points are static methods, so the class is used as a namespace and
    is not instantiated. Call ``Runner.run(pipeline_config, start_at=...)`` to
    execute the pipeline from ``start_at`` through the inferencer.
    """

    @staticmethod
    def run(
        pipeline_config: PipelineConfig,
        start_at: str = GiGLComponents.ConfigPopulator.value,
    ) -> None:
        """
        Runs the GiGL pipeline locally starting from the specified component.

        Components run in this order: ConfigPopulator, DataPreprocessor,
        SubgraphSampler, SplitGenerator, Trainer, Inferencer. Every component from
        ``start_at`` onward is executed.

        When starting at the config populator, it is run exactly once. The URI it
        returns replaces ``pipeline_config.task_config_uri`` before the remaining
        components run. For any other ``start_at``, the config checks registered
        for that component are run first (see ``config_check``).

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
                ``task_config_uri`` is mutated in place when starting from the
                config populator.
            start_at (str): Component to start the pipeline from. Defaults to 'config_populator'.

        Returns:
            None

        Raises:
            ValueError: If ``start_at`` does not name a known pipeline component.
        """
        # Pipeline components in execution order, keyed by their GiGLComponents value.
        component_map: OrderedDict[str, Callable] = OrderedDict(
            [
                (GiGLComponents.ConfigPopulator.value, Runner.run_config_populator),
                (GiGLComponents.DataPreprocessor.value, Runner.run_data_preprocessor),
                (GiGLComponents.SubgraphSampler.value, Runner.run_subgraph_sampler),
                (GiGLComponents.SplitGenerator.value, Runner.run_split_generator),
                (GiGLComponents.Trainer.value, Runner.run_trainer),
                (GiGLComponents.Inferencer.value, Runner.run_inferencer),
            ]
        )
        # List membership/index compare with ==, the same comparison the
        # original loop used to locate the start component.
        component_names = list(component_map.keys())
        if start_at not in component_names:
            # Fail fast instead of silently running nothing.
            raise ValueError(
                f"Unknown start_at component {start_at!r}; "
                f"expected one of {component_names}"
            )
        start_index = component_names.index(start_at)

        logger.info(
            f"Running pipeline from component {start_at} with parameters: \n"
            f"job_name: {pipeline_config.applied_task_identifier}\n"
            f"task_config_uri: {pipeline_config.task_config_uri}\n"
            f"resource_config_uri: {pipeline_config.resource_config_uri}\n"
            f"dataflow_docker_uri: {pipeline_config.dataflow_docker_uri}"
        )
        initialize_metrics(
            task_config_uri=pipeline_config.task_config_uri,
            service_name=pipeline_config.applied_task_identifier,
        )
        if start_at == GiGLComponents.ConfigPopulator.value:
            frozen_config_uri = Runner.run_config_populator(pipeline_config)
            pipeline_config.task_config_uri = frozen_config_uri
            # The config populator has just run. Skip it below so it is not
            # re-run against the already-frozen config.
            start_index += 1
        else:
            Runner.config_check(start_at, pipeline_config)

        for method in list(component_map.values())[start_index:]:
            method(pipeline_config)

    @staticmethod
    def config_check(start_at: str, pipeline_config: PipelineConfig):
        """
        Validates the task config before starting the pipeline mid-way.

        Loads ``pipeline_config.task_config_uri`` as a ``GbmlConfig`` proto. It
        then invokes every check registered for ``start_at``: first those in
        START_COMPONENT_TO_CLS_CHECKS_MAP, then those in
        START_COMPONENT_TO_ASSET_CHECKS_MAP. A component with no registered
        checks runs none.

        Args:
            start_at (str): Component the pipeline will start from.
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        proto_utils = ProtoUtils()
        gbml_config_pb: gbml_config_pb2.GbmlConfig = proto_utils.read_proto_from_yaml(
            uri=pipeline_config.task_config_uri, proto_cls=gbml_config_pb2.GbmlConfig
        )
        for cls_check in START_COMPONENT_TO_CLS_CHECKS_MAP.get(start_at, []):
            cls_check(gbml_config_pb=gbml_config_pb)
        for asset_check in START_COMPONENT_TO_ASSET_CHECKS_MAP.get(start_at, []):
            asset_check(gbml_config_pb=gbml_config_pb)

    @staticmethod
    def run_config_populator(pipeline_config: PipelineConfig) -> Uri:
        """
        Runs the ConfigPopulator component.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.

        Returns:
            Uri: The URI returned by ``ConfigPopulator.run``. ``Runner.run`` uses
            it as the frozen task config URI for later components.
        """
        logger.info("Running Config Populator...")
        config_populator = ConfigPopulator()
        return config_populator.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
        )

    @staticmethod
    def run_data_preprocessor(pipeline_config: PipelineConfig) -> None:
        """
        Runs the DataPreprocessor component.

        ``pipeline_config.dataflow_docker_uri`` is passed as the custom worker image.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Data Preprocessor...")
        data_preprocessor = DataPreprocessor()
        data_preprocessor.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
            custom_worker_image_uri=pipeline_config.dataflow_docker_uri,
        )

    @staticmethod
    def run_subgraph_sampler(pipeline_config: PipelineConfig) -> None:
        """
        Runs the SubgraphSampler component.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Subgraph Sampler...")
        subgraph_sampler = SubgraphSampler()
        subgraph_sampler.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
        )

    @staticmethod
    def run_split_generator(pipeline_config: PipelineConfig) -> None:
        """
        Runs the SplitGenerator component.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Split Generator...")
        split_generator = SplitGenerator()
        split_generator.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
        )

    @staticmethod
    def run_trainer(pipeline_config: PipelineConfig) -> None:
        """
        Runs the Trainer component.

        The custom CPU and CUDA docker URIs from the config are forwarded.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Trainer...")
        trainer = Trainer()
        trainer.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
            cpu_docker_uri=pipeline_config.custom_cpu_docker_uri,
            cuda_docker_uri=pipeline_config.custom_cuda_docker_uri,
        )

    @staticmethod
    def run_inferencer(pipeline_config: PipelineConfig) -> None:
        """
        Runs the Inferencer component.

        ``dataflow_docker_uri`` is passed as the custom worker image. The custom
        CPU and CUDA docker URIs are also forwarded.

        Args:
            pipeline_config (PipelineConfig): Configuration for the pipeline.
        """
        logger.info("Running Inferencer...")
        inferencer = Inferencer()
        inferencer.run(
            applied_task_identifier=pipeline_config.applied_task_identifier,
            task_config_uri=pipeline_config.task_config_uri,
            resource_config_uri=pipeline_config.resource_config_uri,
            custom_worker_image_uri=pipeline_config.dataflow_docker_uri,
            cpu_docker_uri=pipeline_config.custom_cpu_docker_uri,
            cuda_docker_uri=pipeline_config.custom_cuda_docker_uri,
        )