# Source code for gigl.env.pipelines_config

import argparse
import json
import os
from typing import Optional

from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
    GiglResourceConfigWrapper,
)
from snapchat.research.gbml.gigl_resource_config_pb2 import GiglResourceConfig

logger = Logger()


def _try_loading_resource_config_uri_from_pipeline_options() -> Optional[str]:
    """
    Tries to load the resource config URI from the pipeline options.

    Reads the ``PIPELINE_OPTIONS`` environment variable (defaulting to ``"{}"``),
    decodes it as JSON and scans its ``display_data`` list for the first entry
    whose ``key`` is ``"resource_config_uri"`` and that carries a ``value``.

    Returns:
        The ``value`` of the first matching ``display_data`` entry, or None if the
        variable is not valid JSON, is not a JSON object, has a non-list
        ``display_data``, or contains no matching entry.
    """
    logger.info(
        "Could not find resource config path from parsed args... Assuming running Dataflow job"
    )
    try:
        pipeline_options = json.loads(os.environ.get("PIPELINE_OPTIONS", "{}"))
    except json.JSONDecodeError:
        logger.error("Failed to decode PIPELINE_OPTIONS as JSON.")
        return None

    # Valid JSON of the wrong shape (e.g. a top-level list or string) used to
    # raise AttributeError on `.get`; treat it as "not found" instead, which is
    # what this best-effort lookup promises.
    if not isinstance(pipeline_options, dict):
        logger.error("PIPELINE_OPTIONS is not a JSON object; ignoring it.")
        return None

    display_data = pipeline_options.get("display_data", [])
    if not isinstance(display_data, list):
        logger.error("PIPELINE_OPTIONS display_data is not a list; ignoring it.")
        return None

    # Skip malformed entries rather than raising on `.get` / `["value"]`.
    resource_config_path = next(
        (
            item["value"]
            for item in display_data
            if isinstance(item, dict)
            and item.get("key") == "resource_config_uri"
            and "value" in item
        ),
        None,
    )
    if resource_config_path is None:
        logger.info(
            "No resource_config_uri entry found in PIPELINE_OPTIONS display_data."
        )
    else:
        logger.info(f"Found resource config path: {resource_config_path}")
    return resource_config_path


# Module-level cache populated by get_resource_config(); None until first load.
_resource_config: Optional[GiglResourceConfigWrapper] = None
def get_resource_config(
    resource_config_uri: Optional[Uri] = None,
) -> GiglResourceConfigWrapper:
    """
    Function call to return a resource config wrapper object.

    The config is loaded on the first call and cached in the module-level
    ``_resource_config``; every later call returns that cached object.

    Usage:
        resource_config = get_resource_config()
        print(resource_config.trainer_config)

    Args:
        resource_config_uri: Optional[Uri] = None
            The URI of the resource config file. If None, the URI is resolved,
            in order, from the command-line argument --resource_config_uri,
            the environment variable RESOURCE_CONFIG_PATH, and finally the
            pipeline options (PIPELINE_OPTIONS display data). Empty values are
            treated as unset. Ignored if a config is already cached; a warning
            is logged when it differs from the recorded RESOURCE_CONFIG_PATH.

    Returns:
        resource_config: GiglResourceConfigWrapper
            The resource config wrapper object

    Raises:
        ValueError: If no resource config URI can be resolved from any source.
    """
    global _resource_config
    if _resource_config is not None:
        # The cache holds a single config, so an explicit URI cannot override
        # an already-loaded one. Surface that instead of silently ignoring it.
        # RESOURCE_CONFIG_PATH is used for the comparison because this function
        # records the resolved URI there before loading (it may differ if
        # something else rewrote the variable since).
        loaded_from = os.environ.get("RESOURCE_CONFIG_PATH")
        if (
            resource_config_uri is not None
            and str(resource_config_uri) != loaded_from
        ):
            logger.warning(
                f"Ignoring resource_config_uri={resource_config_uri}: a resource "
                f"config is already loaded (RESOURCE_CONFIG_PATH={loaded_from})."
            )
        return _resource_config

    resource_config_str: Optional[str]
    if resource_config_uri is not None:
        resource_config_str = str(resource_config_uri)
    else:
        # add_help=False: this parser only peeks at sys.argv for one flag. With
        # the default add_help=True, a "-h"/"--help" meant for the host program
        # would make argparse print this parser's help and call sys.exit().
        # allow_abbrev=False: don't let unrelated flags that are prefixes of
        # --resource_config_uri (e.g. "--resource_config") be consumed here.
        parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
        parser.add_argument(
            "--resource_config_uri",
            type=str,
            required=False,
        )
        args, _ = parser.parse_known_args()
        resource_config_str = args.resource_config_uri or os.getenv(
            "RESOURCE_CONFIG_PATH"
        )
        # `not` rather than `is None`: an exported-but-empty
        # RESOURCE_CONFIG_PATH must not block the pipeline-options fallback.
        if not resource_config_str:
            resource_config_str = (
                _try_loading_resource_config_uri_from_pipeline_options()
            )

    if not resource_config_str:
        raise ValueError(
            "No resource config provided, either via command-line argument or environment variable."
        )

    # Export the resolved URI so later resolutions via RESOURCE_CONFIG_PATH
    # (in this process, or in child processes inheriting the environment)
    # pick up the same config.
    os.environ["RESOURCE_CONFIG_PATH"] = resource_config_str
    resource_config_path = UriFactory.create_uri(uri=resource_config_str)

    # Imported lazily rather than at module top; presumably to avoid an import
    # cycle with gigl.common.utils — TODO confirm.
    from gigl.common.utils.proto_utils import ProtoUtils

    proto_utils = ProtoUtils()
    _resource_config = GiglResourceConfigWrapper(
        proto_utils.read_proto_from_yaml(
            resource_config_path, proto_cls=GiglResourceConfig
        )
    )
    return _resource_config
def is_resource_config_loaded() -> bool:
    """
    Report whether the module-level resource config cache is populated.

    Returns:
        True once a resource config has been loaded into the cache (for
        example by get_resource_config()), False while it is still empty.
    """
    cached_config = _resource_config
    return cached_config is not None
if __name__ == "__main__":
    # Resolve and load the resource config (no explicit URI, so CLI args,
    # environment and pipeline options are consulted), then print it.
    print(get_resource_config())