Source code for gigl.src.common.utils.dataflow

from typing import Any, Optional

from apache_beam.options.pipeline_options import (
    DebugOptions,
    GoogleCloudOptions,
    PipelineOptions,
    SetupOptions,
    StandardOptions,
    WorkerOptions,
)

from gigl.common import UriFactory
from gigl.common.logger import Logger
from gigl.env.dep_constants import GIGL_DATAFLOW_IMAGE
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants import gcs as gcs_constants
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.dataflow_job_options import CommonOptions

logger = Logger()

def get_sanitized_dataflow_job_name(name: str) -> str:
    """Sanitizes ``name`` into a Dataflow-safe job name: lowercased, underscores
    replaced with hyphens, and all other non-alphanumeric characters dropped."""
    name = name.lower()
    name = name.replace("_", "-")
    name = "".join([c for c in name if c.isalnum() or c == "-"])
    logger.info(f"Will use sanitized dataflow job name: {name}")
    return name
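
# Illustrative usage (not part of the original module): the sanitizer lowercases the
# name, swaps underscores for hyphens, and strips any remaining non-alphanumeric
# characters, e.g. for a hypothetical task name:
#
#   get_sanitized_dataflow_job_name("gigl-My_Task-inference")
#   # -> "gigl-my-task-inference"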

def init_beam_pipeline_options(
    applied_task_identifier: AppliedTaskIdentifier,
    job_name_suffix: str,
    component: Optional[GiGLComponents] = None,
    custom_worker_image_uri: Optional[str] = None,
    **kwargs: Any,
) -> PipelineOptions:
    """Builds Beam ``PipelineOptions`` with GiGL's Dataflow defaults.

    Any additional pipeline options can be passed as keyword arguments, e.g.
    ``init_beam_pipeline_options(num_workers=1, max_num_workers=32, ...)``.
    Options passed this way override the defaults defined here; for example, you can
    override the job name by passing ``job_name="something"``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): Identifier for the applied task.
        job_name_suffix (str): Unique identifier for the Dataflow job in relation to
            this task (applied_task_identifier), e.g. job_name_suffix = "inference".

    Returns:
        PipelineOptions: Options you can use to construct the pipeline.
    """
    job_name = get_sanitized_dataflow_job_name(
        f"gigl-{applied_task_identifier}-{job_name_suffix}"
    )
    options = PipelineOptions(**kwargs)
    common_options = options.view_as(CommonOptions)
    resource_config_uri = UriFactory.create_uri(
        uri=get_resource_config().get_resource_config_uri
    )
    common_options.resource_config_uri = get_resource_config().get_resource_config_uri
    # https://cloud.google.com/dataflow/docs/guides/build-container-image#pre-build_using_a_dockerfile
    setup_options = options.view_as(SetupOptions)
    setup_options.sdk_location = "container"
    worker_options: WorkerOptions = options.view_as(WorkerOptions)
    worker_options.sdk_container_image = custom_worker_image_uri or GIGL_DATAFLOW_IMAGE
    debug_options = options.view_as(DebugOptions)
    debug_options.experiments = debug_options.experiments or [
        "shuffle_mode=service",
        "use_runner_v2",
        "enable_stackdriver_agent_metrics",
        # Allows the job graph to exceed 10MB. Ideally we should try to limit graph
        # size; this is a temporary band-aid for large job graphs.
        # https://cloud.google.com/knowledge/kb/dataflow-job-fails-with-error-message-for-large-job-graphs-000007130
        "upload_graph",
    ]
    standard_options = options.view_as(StandardOptions)
    standard_options.runner = (
        standard_options.runner or get_resource_config().dataflow_runner
    )
    google_cloud_options = options.view_as(GoogleCloudOptions)
    google_cloud_options.labels = google_cloud_options.labels or (
        get_resource_config().get_resource_labels_formatted_for_dataflow(
            component=component
        )
    )
    google_cloud_options.project = (
        google_cloud_options.project or get_resource_config().project
    )
    google_cloud_options.job_name = job_name
    google_cloud_options.staging_location = (
        google_cloud_options.staging_location
        or gcs_constants.get_dataflow_staging_gcs_path(
            applied_task_identifier=applied_task_identifier, job_name=job_name
        ).uri
    )
    google_cloud_options.temp_location = (
        google_cloud_options.temp_location
        or gcs_constants.get_dataflow_temp_gcs_path(
            applied_task_identifier=applied_task_identifier, job_name=job_name
        ).uri
    )
    google_cloud_options.region = (
        google_cloud_options.region or get_resource_config().region
    )
    # For context see: https://cloud.google.com/dataflow/docs/reference/service-options#python
    # This is different from how `num_workers` is used by the default `PipelineOptions`
    # exposed by Beam: with only `num_workers` set, the Dataflow service may still
    # downscale to 1 worker, whereas setting `min_num_workers` explicitly in
    # `dataflow_service_options` ensures the service will not scale below that number.
    if kwargs.get("num_workers"):
        num_workers = kwargs.get("num_workers")
        logger.info(
            f"Setting `min_num_workers` for Dataflow explicitly to {num_workers}"
        )
        dataflow_service_options = google_cloud_options.dataflow_service_options or []
        dataflow_service_options.append(f"min_num_workers={num_workers}")
        google_cloud_options.dataflow_service_options = dataflow_service_options

    google_cloud_options.service_account_email = (
        google_cloud_options.service_account_email
        or get_resource_config().service_account_email
    )
    return options
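
# Illustrative usage sketch (not part of the original module). Assumes an existing
# `AppliedTaskIdentifier` value `ati`; the worker counts are the examples from the
# docstring above. Passing `num_workers` here also pins `min_num_workers` via
# `dataflow_service_options`, as handled at the end of `init_beam_pipeline_options`.
#
#   import apache_beam as beam
#
#   options = init_beam_pipeline_options(
#       applied_task_identifier=ati,
#       job_name_suffix="inference",
#       num_workers=1,
#       max_num_workers=32,
#   )
#   with beam.Pipeline(options=options) as pipeline:
#       _ = pipeline | "CreateExample" >> beam.Create([1, 2, 3])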