Source code for gigl.src.common.utils.file_loader

import contextlib
import os
import tempfile
from collections.abc import Mapping
from tempfile import _TemporaryFileWrapper as TemporaryFileWrapper  # type: ignore
from typing import Dict, List, Optional, Sequence, Tuple, Type, Union, cast

from gigl.common import GcsUri, LocalUri, Uri, UriFactory
from gigl.common.logger import Logger
from gigl.common.utils.gcs import GcsUtils
from gigl.common.utils.local_fs import (
    FileSystemEntity,
    copy_files,
    count_files_with_uri_prefix,
    create_file_symlinks,
    does_path_exist,
    list_at_path,
    remove_file_or_folder_if_exist,
)

# Module-level logger shared by FileLoader for transfer progress messages.
logger = Logger()
class FileLoader:
    """Transfer files between local disk and GCS, dispatching on the ``Uri`` subclasses.

    The ``load_*`` methods support these (source, destination) combinations:

    * ``GcsUri -> LocalUri``: download through ``GcsUtils``.
    * ``LocalUri -> GcsUri``: upload through ``GcsUtils``.
    * ``LocalUri -> LocalUri``: symlink (the default) or copy on the local file system.

    Any other combination raises ``TypeError``.
    """

    def __init__(self, project: Optional[str] = None):
        """
        Args:
            project (Optional[str]): Passed straight through to ``GcsUtils``
                (presumably the GCP project to operate in -- confirm).
        """
        gcs_utils = GcsUtils(project)
        self.__gcs_utils = gcs_utils
        self.__unsupported_uri_message = (
            f"{self.__class__.__name__} does not support Uris of this type."
        )

    @staticmethod
    def __get_uri_map_schema(
        uri_map: Mapping[Uri, Uri]
    ) -> Tuple[Optional[Type[Uri]], Optional[Type[Uri]]]:
        """Return the common (source, destination) ``Uri`` subclasses of ``uri_map``.

        Each element of the pair is the single class shared by every key
        (respectively every value), or ``None`` when they are of mixed classes.

        An empty map yields ``(None, None)`` instead of raising ``IndexError``.
        ``load_files`` treats that like a mixed map and iterates its zero pairs,
        so an empty transfer (e.g. an empty local directory) is a no-op.
        """
        if not uri_map:
            return None, None
        src_types: List[Type[Uri]] = [uri.__class__ for uri in uri_map.keys()]
        dst_types: List[Type[Uri]] = [uri.__class__ for uri in uri_map.values()]
        uniform_src_type: Optional[Type[Uri]] = (
            src_types[0] if all(src_types[0] == x for x in src_types) else None
        )
        uniform_dst_type: Optional[Type[Uri]] = (
            dst_types[0] if all(dst_types[0] == x for x in dst_types) else None
        )
        return uniform_src_type, uniform_dst_type

    def load_directories(
        self,
        source_to_dest_directory_map: Dict[Uri, Uri],
    ):
        """Call ``load_directory`` for every (source, destination) directory pair.

        Args:
            source_to_dest_directory_map (Dict[Uri, Uri]): Source directory ->
                destination directory.
        """
        for dir_uri_src, dir_uri_dst in source_to_dest_directory_map.items():
            self.load_directory(dir_uri_src=dir_uri_src, dir_uri_dst=dir_uri_dst)

    def load_directory(self, dir_uri_src: Uri, dir_uri_dst: Uri):
        """Transfer the files of directory ``dir_uri_src`` into ``dir_uri_dst``.

        * ``GcsUri -> LocalUri``: delegates to
          ``GcsUtils.download_files_from_gcs_paths_to_local_dir``.
        * ``LocalUri -> GcsUri`` and ``LocalUri -> LocalUri``: each file that
          ``list_at_path`` reports for ``dir_uri_src`` is mapped to the same file
          name under ``dir_uri_dst`` and the map is passed to ``load_files``.
          ``load_files`` runs with its defaults, so local -> local creates symlinks.

        Args:
            dir_uri_src (Uri): Source directory.
            dir_uri_dst (Uri): Destination directory.

        Raises:
            TypeError: If the (source, destination) combination is unsupported.
        """
        uri_map_schema = self.__get_uri_map_schema(uri_map={dir_uri_src: dir_uri_dst})
        if uri_map_schema == (GcsUri, LocalUri):
            dir_uri_src = cast(GcsUri, dir_uri_src)
            dir_uri_dst = cast(LocalUri, dir_uri_dst)
            self.__gcs_utils.download_files_from_gcs_paths_to_local_dir(
                gcs_paths=[dir_uri_src], local_path_dir=dir_uri_dst
            )
        elif uri_map_schema == (LocalUri, GcsUri):
            dir_uri_src = cast(LocalUri, dir_uri_src)
            dir_uri_dst = cast(GcsUri, dir_uri_dst)
            local_paths: List[LocalUri] = list_at_path(
                local_path=dir_uri_src, file_system_entity=FileSystemEntity.FILE
            )
            # Use the same FILE filter as ``local_paths`` (as the local -> local
            # branch does). Otherwise, if list_at_path also lists directories by
            # default, sub-directory names would shift the zip() pairing below.
            # Both listings are assumed to come back in the same order.
            gcs_paths: List[GcsUri] = [
                GcsUri.join(dir_uri_dst, local_fn.uri)
                for local_fn in list_at_path(
                    local_path=dir_uri_src,
                    names_only=True,
                    file_system_entity=FileSystemEntity.FILE,
                )
            ]
            local_file_path_to_gcs_path_map: Dict[LocalUri, GcsUri] = {
                src: dst for src, dst in zip(local_paths, gcs_paths)
            }
            self.load_files(
                source_to_dest_file_uri_map=cast(
                    Dict[Uri, Uri], local_file_path_to_gcs_path_map
                )
            )
        elif uri_map_schema == (LocalUri, LocalUri):
            dir_uri_src = cast(LocalUri, dir_uri_src)
            dir_uri_dst = cast(LocalUri, dir_uri_dst)
            local_src_paths: List[LocalUri] = list_at_path(
                local_path=dir_uri_src, file_system_entity=FileSystemEntity.FILE
            )
            local_dst_paths: List[LocalUri] = [
                LocalUri.join(dir_uri_dst, local_src_fn)
                for local_src_fn in list_at_path(
                    local_path=dir_uri_src,
                    names_only=True,
                    file_system_entity=FileSystemEntity.FILE,
                )
            ]
            source_to_dest_file_uri_map = {
                src: dst for src, dst in zip(local_src_paths, local_dst_paths)
            }
            self.load_files(
                source_to_dest_file_uri_map=cast(
                    Dict[Uri, Uri], source_to_dest_file_uri_map
                )
            )
        else:
            raise TypeError(self.__unsupported_uri_message)

    def load_files(
        self,
        source_to_dest_file_uri_map: Mapping[Uri, Uri],
        should_create_symlinks_if_possible: bool = True,
    ) -> None:
        """Transfer many files, batching when every pair has the same URI types.

        A uniform ``GcsUri -> LocalUri`` or ``LocalUri -> GcsUri`` map is sent to
        ``GcsUtils`` in one call; uploads use ``parallel=True``. A uniform
        ``LocalUri -> LocalUri`` map is symlinked or copied in one call, and
        existing destinations are overwritten.

        Any other map (mixed types, or an unsupported uniform schema) is handled
        pair by pair through ``load_file``. That method raises ``TypeError`` for
        an unsupported pair. An empty map is a no-op.

        Args:
            source_to_dest_file_uri_map (Mapping[Uri, Uri]): Source file ->
                destination file.
            should_create_symlinks_if_possible (bool): For local -> local pairs,
                create symlinks when True, otherwise copy the files.
        """
        uri_map_schema = self.__get_uri_map_schema(uri_map=source_to_dest_file_uri_map)
        if uri_map_schema == (GcsUri, LocalUri):
            logger.info("Downloading from GCS to Local")
            self.__gcs_utils.download_files_from_gcs_paths_to_local_paths(
                file_map=cast(Dict[GcsUri, LocalUri], source_to_dest_file_uri_map)
            )
        elif uri_map_schema == (LocalUri, GcsUri):
            logger.info("Uploading from Local to GCS")
            self.__gcs_utils.upload_files_to_gcs(
                local_file_path_to_gcs_path_map=cast(
                    Dict[LocalUri, GcsUri], source_to_dest_file_uri_map
                ),
                parallel=True,
            )
        elif uri_map_schema == (LocalUri, LocalUri):
            logger.info("Copying from Local to Local")
            local_source_to_link_path_map = source_to_dest_file_uri_map
            if should_create_symlinks_if_possible:
                logger.info("Will create symlinks")
                create_file_symlinks(
                    local_source_to_link_path_map=cast(
                        Dict[LocalUri, LocalUri], local_source_to_link_path_map
                    ),
                    should_overwrite=True,
                )
            else:
                logger.info("Will copy files")
                copy_files(
                    local_source_to_local_dst_path_map=cast(
                        Dict[LocalUri, LocalUri], local_source_to_link_path_map
                    ),
                    should_overwrite=True,
                )
        else:
            for file_uri_src, file_uri_dst in source_to_dest_file_uri_map.items():
                self.load_file(
                    file_uri_src=file_uri_src,
                    file_uri_dst=file_uri_dst,
                    should_create_symlinks_if_possible=should_create_symlinks_if_possible,
                )

    def load_file(
        self,
        file_uri_src: Uri,
        file_uri_dst: Uri,
        should_create_symlinks_if_possible: bool = True,
    ) -> None:
        """Transfer a single file from ``file_uri_src`` to ``file_uri_dst``.

        Args:
            file_uri_src (Uri): Source file.
            file_uri_dst (Uri): Destination file.
            should_create_symlinks_if_possible (bool): For a local -> local pair,
                create a symlink when True, otherwise copy. An existing
                destination is overwritten either way.

        Raises:
            TypeError: If the (source, destination) combination is unsupported.
                A warning is logged before raising.
        """
        uri_map = {file_uri_src: file_uri_dst}
        uri_map_schema = self.__get_uri_map_schema(uri_map=uri_map)
        if uri_map_schema == (GcsUri, LocalUri):
            self.__gcs_utils.download_file_from_gcs(
                gcs_path=cast(GcsUri, file_uri_src),
                dest_file_path=cast(LocalUri, file_uri_dst),
            )
        elif uri_map_schema == (LocalUri, GcsUri):
            self.__gcs_utils.upload_files_to_gcs(
                local_file_path_to_gcs_path_map=cast(Dict[LocalUri, GcsUri], uri_map),
                parallel=False,
            )
        elif uri_map_schema == (LocalUri, LocalUri):
            local_source_to_link_path_map = {file_uri_src: file_uri_dst}
            if should_create_symlinks_if_possible:
                create_file_symlinks(
                    local_source_to_link_path_map=cast(
                        Dict[LocalUri, LocalUri], local_source_to_link_path_map
                    ),
                    should_overwrite=True,
                )
            else:
                copy_files(
                    local_source_to_local_dst_path_map=cast(
                        Dict[LocalUri, LocalUri], local_source_to_link_path_map
                    ),
                    should_overwrite=True,
                )
        else:
            logger.warning(f"Unsupported uri_map_schema: {uri_map_schema}")
            raise TypeError(self.__unsupported_uri_message)

    def load_to_temp_file(
        self,
        file_uri_src: Uri,
        delete: bool = False,
        should_create_symlinks_if_possible: bool = True,
    ) -> TemporaryFileWrapper:
        """Load ``file_uri_src`` into a new ``NamedTemporaryFile`` and return its handle.

        If loading fails, the temporary file is closed, and removed from disk
        when ``delete`` is False, before the exception propagates. The caller
        never receives the handle in that case, so it would otherwise leak.

        NOTE(review): for a ``LocalUri`` source with symlinks enabled, the temp
        path is overwritten by ``load_file`` (presumably with a symlink), so the
        handle's already-open descriptor may not see the loaded content -- confirm.

        Args:
            file_uri_src (Uri): File to load.
            delete (bool): Forwarded to ``NamedTemporaryFile``; when True the file
                is deleted once the handle is closed.
            should_create_symlinks_if_possible (bool): Forwarded to ``load_file``.

        Returns:
            TemporaryFileWrapper: The open temporary-file handle.
        """
        temp_file_handle = tempfile.NamedTemporaryFile(delete=delete)
        temp_file_path = LocalUri(str(temp_file_handle.name))
        try:
            self.load_file(
                file_uri_src=file_uri_src,
                file_uri_dst=temp_file_path,
                should_create_symlinks_if_possible=should_create_symlinks_if_possible,
            )
        except BaseException:
            # Best-effort cleanup; never let it mask the original error.
            with contextlib.suppress(OSError):
                temp_file_handle.close()
            if not delete:
                with contextlib.suppress(OSError):
                    os.remove(temp_file_handle.name)
            raise
        return temp_file_handle

    def count_assets(self, uri_prefix: Uri, suffix: Optional[str] = None) -> int:
        """Count the files under ``uri_prefix``, optionally only those ending in ``suffix``.

        Args:
            uri_prefix (Uri): A ``GcsUri`` or ``LocalUri`` prefix.
            suffix (Optional[str]): Optional suffix filter forwarded to the backend.

        Returns:
            int: The count reported by ``GcsUtils.count_blobs_in_gcs_path`` or
            ``count_files_with_uri_prefix``.

        Raises:
            TypeError: If ``uri_prefix`` is neither a ``GcsUri`` nor a ``LocalUri``.
        """
        if isinstance(uri_prefix, GcsUri):
            return self.__gcs_utils.count_blobs_in_gcs_path(
                gcs_path=uri_prefix, suffix=suffix
            )
        elif isinstance(uri_prefix, LocalUri):
            return count_files_with_uri_prefix(uri_prefix=uri_prefix, suffix=suffix)
        else:
            raise TypeError(
                f"Uri type not supported, got {uri_prefix} in type {type(uri_prefix)}"
            )

    def does_uri_exist(self, uri: Union[str, Uri]) -> bool:
        """
        Check if a URI exists.

        Args:
            uri (Union[str, Uri]): URI to check. Strings are converted with
                ``UriFactory.create_uri``.

        Returns:
            bool: True if the URI exists, False otherwise.

        Raises:
            NotImplementedError: If the URI is neither a valid GCS nor local URI.
        """
        _uri = UriFactory.create_uri(uri=uri) if isinstance(uri, str) else uri
        exists: bool
        if GcsUri.is_valid(uri=_uri, raise_exception=False):
            exists = self.__gcs_utils.does_gcs_file_exist(gcs_path=_uri)  # type: ignore
        elif LocalUri.is_valid(uri=_uri, raise_exception=False):
            exists = does_path_exist(cast(LocalUri, _uri))
        else:
            raise NotImplementedError(f"{self.__unsupported_uri_message} : {_uri}")
        return exists

    def delete_files(self, uris: List[Uri]) -> None:
        """
        Recursively delete files in the specified URIs.

        Args:
            uris (List[Uri]): URIs to delete. Local URIs are removed with
                ``remove_file_or_folder_if_exist``. GCS URIs are removed with
                ``GcsUtils.delete_files_in_bucket_dir``.

        Returns:
            None

        Raises:
            NotImplementedError: For a URI that is neither local nor GCS.
                URIs earlier in the list have already been deleted by then.
        """
        for uri in uris:
            if isinstance(uri, LocalUri):
                remove_file_or_folder_if_exist(local_path=uri)
            elif isinstance(uri, GcsUri):
                self.__gcs_utils.delete_files_in_bucket_dir(gcs_path=uri)
            else:
                raise NotImplementedError(
                    f"Cannot delete URI {uri.uri} of type {type(uri)}; {self.__unsupported_uri_message}"
                )

    def list_children(self, uri: Uri, pattern: Optional[str] = None) -> Sequence[Uri]:
        """
        List all children of the given URI.

        Args:
            uri (Uri): The URI to list children of.
            pattern (Optional[str]): Optional regex to match. If not provided then
                all children will be returned.

        Returns:
            Sequence[Uri]: URIs for the children of the given URI.

        Raises:
            NotImplementedError: If ``uri`` is neither a ``GcsUri`` nor a ``LocalUri``.
        """
        if isinstance(uri, GcsUri):
            return self.__gcs_utils.list_uris_with_gcs_path_pattern(
                gcs_path=uri, pattern=pattern
            )
        elif isinstance(uri, LocalUri):
            return list_at_path(local_path=uri, regex=pattern)
        else:
            raise NotImplementedError(
                f"Cannot list children of URI {uri.uri} of type {type(uri)}; {self.__unsupported_uri_message}"
            )