Source code for gigl.src.subgraph_sampler.subgraph_sampler

import argparse
import datetime
import os
from distutils.util import strtobool
from typing import Optional, Sequence

import gigl.env.dep_constants as dep_constants
import gigl.src.common.constants.gcs as gcs_constants
import gigl.src.common.constants.metrics as metrics_constants
from gigl.common import GcsUri, LocalUri, Uri, UriFactory
from gigl.common.constants import (
    SPARK_31_TFRECORD_JAR_GCS_PATH,
    SPARK_35_TFRECORD_JAR_GCS_PATH,
)
from gigl.common.logger import Logger
from gigl.common.metrics.decorators import flushes_metrics, profileit
from gigl.common.utils import os_utils
from gigl.common.utils.gcs import GcsUtils
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.constants.metrics import TIMER_SUBGRAPH_SAMPLER_S
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.file_loader import FileLoader
from gigl.src.common.utils.metrics_service_provider import (
    get_metrics_service_instance,
    initialize_metrics,
)
from gigl.src.common.utils.spark_job_manager import (
    DataprocClusterInitData,
    SparkJobManager,
)
from gigl.src.subgraph_sampler.lib.ingestion_protocol import BaseIngestion

logger = Logger()

# Allowed max job duration for SGS job -- for MAU workload
MAX_JOB_DURATION = datetime.timedelta(hours=5)
class SubgraphSampler:
    """
    GiGL Component that generates k-hop localized subgraphs for each node in the graph
    using Spark/Scala running on Dataproc.
    """

    def __prepare_staging_paths(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        gbml_config_pb_wrapper: GbmlConfigPbWrapper,
    ) -> None:
        # Clear paths that Subgraph Sampler would be writing to, to avoid clobbering of data.
        # Some of these paths are inferred from paths specified in the GbmlConfig.
        # Other paths are inferred from the AppliedTaskIdentifier.
        logger.info("Preparing staging paths for Subgraph Sampler...")
        paths_to_delete = (
            [
                gcs_constants.get_subgraph_sampler_root_dir(
                    applied_task_identifier=applied_task_identifier
                )
            ]
            + gbml_config_pb_wrapper.flattened_graph_metadata_pb_wrapper.get_output_paths()
        )
        file_loader = FileLoader()
        file_loader.delete_files(uris=paths_to_delete)
        logger.info("Staging paths for Subgraph Sampler prepared.")

    @flushes_metrics(get_metrics_service_instance_fn=get_metrics_service_instance)
    @profileit(
        metric_name=TIMER_SUBGRAPH_SAMPLER_S,
        get_metrics_service_instance_fn=get_metrics_service_instance,
    )
    def run(
        self,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        cluster_name: Optional[str] = None,
        debug_cluster_owner_alias: Optional[str] = None,
        custom_worker_image_uri: Optional[str] = None,
        skip_cluster_delete: bool = False,
        additional_spark35_jar_file_uris: Sequence[Uri] = (),
    ):
        resource_config = get_resource_config(resource_config_uri=resource_config_uri)
        gbml_config_pb_wrapper: GbmlConfigPbWrapper = (
            GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
                gbml_config_uri=task_config_uri
            )
        )
        self.__prepare_staging_paths(
            applied_task_identifier=applied_task_identifier,
            gbml_config_pb_wrapper=gbml_config_pb_wrapper,
        )
        use_graph_db = (
            gbml_config_pb_wrapper.dataset_config.subgraph_sampler_config.HasField(
                "graph_db_config"
            )
        )
        use_spark35: bool = bool(
            strtobool(
                gbml_config_pb_wrapper.dataset_config.subgraph_sampler_config.experimental_flags.get(
                    "use_spark35_runner", "False"
                )
            )
        )
        if use_graph_db:
            # Run the Spark 3.5 runner if we're using the graphdb version of the subgraph sampler.
            logger.info(
                "Will default to using Spark 3.5 runner for Subgraph Sampler since graph_db_config is set"
            )
            use_spark35 = True
            metrics_service = get_metrics_service_instance()
            if metrics_service is not None:
                metrics_service.add_count(
                    metric_name=metrics_constants.COUNT_SGS_USES_GRAPHDB, count=1
                )
        should_ingest_into_graph_db: bool = use_graph_db and getattr(
            gbml_config_pb_wrapper.dataset_config.subgraph_sampler_config.graph_db_config,
            "graph_db_ingestion_cls_path",
        )
        if should_ingest_into_graph_db:
            graph_db_config = (
                gbml_config_pb_wrapper.dataset_config.subgraph_sampler_config.graph_db_config
            )
            graph_db_ingestion_config_cls_str: str = (
                graph_db_config.graph_db_ingestion_cls_path
            )
            graph_db_ingestion_args = graph_db_config.graph_db_ingestion_args  # type: ignore
            graph_db_args = graph_db_config.graph_db_args
            all_graph_db_args = {**graph_db_ingestion_args, **graph_db_args}
            graph_db_ingestion_cls = os_utils.import_obj(
                obj_path=graph_db_ingestion_config_cls_str
            )
            try:
                graph_db_ingestion_config: BaseIngestion = graph_db_ingestion_cls(
                    **all_graph_db_args
                )
            except Exception as e:
                logger.error(
                    f"Could not instantiate class {graph_db_ingestion_cls} with args {graph_db_ingestion_args}"
                )
                raise e
            logger.info(
                f"Instantiated {graph_db_ingestion_cls} with args {graph_db_ingestion_args}"
            )
            logger.info("Running ingestion...")
            graph_db_ingestion_config.ingest(
                gbml_config_pb_wrapper=gbml_config_pb_wrapper,
                resource_config_uri=resource_config_uri,
                applied_task_identifier=applied_task_identifier,
                custom_worker_image_uri=custom_worker_image_uri,
            )
            logger.info("Ingestion complete. Cleaning up...")
            graph_db_ingestion_config.clean_up()

        # must match pattern (?:[a-z](?:[-a-z0-9]{0,49}[a-z0-9])?)
        if not cluster_name:
            cluster_name = f"sgs_{applied_task_identifier}"
        cluster_name = cluster_name.replace("_", "-")[:50]
        if cluster_name.endswith("-"):
            # Pattern above requires a lowercase alphanumeric final character.
            cluster_name = cluster_name[:-1] + "z"
        gcs_utils = GcsUtils()
        main_jar_file_uri: LocalUri = dep_constants.get_jar_file_uri(
            component=GiGLComponents.SubgraphSampler, use_spark35=use_spark35
        )
        logger.info(f"Using main jar file: {main_jar_file_uri}")
        main_jar_file_name: str = main_jar_file_uri.uri.split("/")[-1]
        jar_file_local_dir: str = os.path.dirname(main_jar_file_uri.uri)
        logger.info(f"Jar file local dir: {jar_file_local_dir}")
        jar_file_gcs_bucket: GcsUri = gcs_constants.get_subgraph_sampler_root_dir(
            applied_task_identifier=applied_task_identifier
        )
        jars_to_upload: dict[Uri, GcsUri] = {
            main_jar_file_uri: GcsUri.join(jar_file_gcs_bucket, main_jar_file_name)
        }
        # Since Spark 3.5 and Spark 3.1 use different versions of Scala, we need to pass
        # the correct extra jar file to the Spark cluster. Otherwise, we may see errors like:
        # java.io.InvalidClassException; local class incompatible:
        #   stream classdesc serialVersionUID = -1, local class serialVersionUID = 2
        if use_spark35:
            for jar_uri in additional_spark35_jar_file_uris:
                jars_to_upload[jar_uri] = GcsUri.join(
                    jar_file_gcs_bucket, jar_uri.get_basename()
                )
        sgs_jar_file_gcs_path = GcsUri.join(
            jar_file_gcs_bucket,
            main_jar_file_name,
        )
        logger.info(f"Uploading jar files {jars_to_upload}")
        FileLoader().load_files(source_to_dest_file_uri_map=jars_to_upload)
        extra_jar_file_uris = [
            jars_to_upload[jar].uri
            for jar in jars_to_upload
            if jar != main_jar_file_uri
        ]
        if use_spark35:
            extra_jar_file_uris.append(SPARK_35_TFRECORD_JAR_GCS_PATH)
        else:
            extra_jar_file_uris.append(SPARK_31_TFRECORD_JAR_GCS_PATH)
        logger.info(f"Will add the following jars to all jobs: {extra_jar_file_uris}")
        resource_config_gcs_path: GcsUri
        task_config_gcs_path: GcsUri
        file_loader = FileLoader()
        if not isinstance(resource_config_uri, GcsUri):
            resource_config_gcs_path = GcsUri.join(
                gcs_constants.get_applied_task_temp_gcs_path(
                    applied_task_identifier=applied_task_identifier
                ),
                "resource_config.yaml",
            )
            logger.info(
                f"Uploading resource config: {resource_config_uri} to gcs: {resource_config_gcs_path}"
            )
            file_loader.load_file(
                file_uri_src=resource_config_uri, file_uri_dst=resource_config_gcs_path
            )
        else:
            resource_config_gcs_path = resource_config_uri
        if not isinstance(task_config_uri, GcsUri):
            task_config_gcs_path = GcsUri.join(
                gcs_constants.get_applied_task_temp_gcs_path(
                    applied_task_identifier=applied_task_identifier
                ),
                "task_config.yaml",
            )
            logger.info(
                f"Uploading task config: {task_config_uri} to gcs: {task_config_gcs_path}"
            )
            file_loader.load_file(
                file_uri_src=task_config_uri, file_uri_dst=task_config_gcs_path
            )
        else:
            task_config_gcs_path = task_config_uri
        logger.info(
            f"Using resource config: {resource_config_gcs_path} and task config: {task_config_gcs_path}"
        )
        dataproc_helper = SparkJobManager(
            project=resource_config.project,
            region=resource_config.region,
            cluster_name=cluster_name,
        )
        cluster_init_data = DataprocClusterInitData(
            project=resource_config.project,
            region=resource_config.region,
            service_account=resource_config.service_account_email,
            cluster_name=cluster_name,
            machine_type=resource_config.subgraph_sampler_config.machine_type,
            temp_assets_bucket=str(resource_config.temp_assets_regional_bucket_path),
            num_workers=resource_config.subgraph_sampler_config.num_replicas,
            num_local_ssds=resource_config.subgraph_sampler_config.num_local_ssds,
            debug_cluster_owner_alias=debug_cluster_owner_alias,
            is_debug_mode=skip_cluster_delete or bool(debug_cluster_owner_alias),
            labels=resource_config.get_resource_labels(
                component=GiGLComponents.SubgraphSampler
            ),
        )
        if use_spark35:
            logger.warning(
                "You are using the Spark 3.5 runner for Subgraph Sampler; not all features are supported yet."
            )
        dataproc_helper.create_dataproc_cluster(
            cluster_init_data=cluster_init_data,
            use_spark35=use_spark35,
        )
        dataproc_helper.submit_and_wait_scala_spark_job(
            main_jar_file_uri=sgs_jar_file_gcs_path.uri,
            max_job_duration=MAX_JOB_DURATION,
            runtime_args=[
                task_config_gcs_path.uri,
                applied_task_identifier,
                resource_config_gcs_path.uri,
            ],
            extra_jar_file_uris=extra_jar_file_uris,
            use_spark35=use_spark35,
        )
        if not skip_cluster_delete:
            logger.info(
                f"skip_cluster_delete marked to {skip_cluster_delete}; will delete cluster"
            )
            dataproc_helper.delete_cluster()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Program to sample subgraphs from preprocessed graph/feature data. "
        + "Using the subgraphs, generates samples that can be consumed by the rest of the pipeline."
    )
parser.add_argument( "--job_name", type=str, help="Unique identifier for the job name", required=True, ) parser.add_argument( "--task_config_uri", type=str, help="Gbml frozen config uri", required=True, ) parser.add_argument( "--resource_config_uri", type=str, help="Runtime argument for resource and env specifications of each component", required=True, ) parser.add_argument( "--cluster_name", type=str, help="Optional param if you want to re-use a cluster for continous development purposes." + "Otherwise, a cluster name will automatically be generated based on job_name", required=False, ) parser.add_argument( "--skip_cluster_delete", action="store_true", help="Provide flag to skip automatic cleanup of dataproc cluster. This way you can re-use the cluster for development purposes", default=False, ) parser.add_argument( "--debug_cluster_owner_alias", type=str, help="debug_cluster_owner_alias", required=False, ) parser.add_argument( "--custom_worker_image_uri", type=str, help="Docker image to use for the worker harness in dataflow jobs (optional)", required=False, ) parser.add_argument( "--additional_spark35_jar_file_uris", action="append", type=str, required=False, default=[], help="Additional URIs to be added to the Spark cluster.", ) args = parser.parse_args() if not args.job_name or not args.task_config_uri or not args.resource_config_uri: raise RuntimeError( f"Missing command-line arguments, expected all of [job_name, task_config_uri, resource_config_uri]. Received: {args}" ) ati = AppliedTaskIdentifier(args.job_name) task_config_uri = UriFactory.create_uri(uri=args.task_config_uri) resource_config_uri = UriFactory.create_uri(uri=args.resource_config_uri) applied_task_identifier = AppliedTaskIdentifier(args.job_name) custom_worker_image_uri = args.custom_worker_image_uri initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) sgs = SubgraphSampler() sgs.run( applied_task_identifier=ati, task_config_uri=task_config_uri, cluster_name=args.cluster_name, debug_cluster_owner_alias=args.debug_cluster_owner_alias, skip_cluster_delete=args.skip_cluster_delete, resource_config_uri=resource_config_uri, custom_worker_image_uri=custom_worker_image_uri, # Filter out empty strings which kfp *may* add... additional_spark35_jar_file_uris=[ UriFactory.create_uri(jar) for jar in args.additional_spark35_jar_file_uris if jar ], )
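For reference, the component can also be driven programmatically rather than through the CLI entry point above. The sketch below mirrors what the `__main__` block does; the job name and GCS paths are hypothetical placeholders, not values from this repository.

# Minimal sketch (illustrative only); assumes a frozen task config and resource config
# already exist at the hypothetical GCS paths below.
from gigl.common import UriFactory
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from gigl.src.subgraph_sampler.subgraph_sampler import SubgraphSampler

job_name = "example-sgs-job"  # hypothetical job name
task_config_uri = UriFactory.create_uri("gs://example-bucket/frozen_task_config.yaml")  # hypothetical
resource_config_uri = UriFactory.create_uri("gs://example-bucket/resource_config.yaml")  # hypothetical

initialize_metrics(task_config_uri=task_config_uri, service_name=job_name)

SubgraphSampler().run(
    applied_task_identifier=AppliedTaskIdentifier(job_name),
    task_config_uri=task_config_uri,
    resource_config_uri=resource_config_uri,
    skip_cluster_delete=False,  # delete the Dataproc cluster once the job finishes
)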