from typing import Optional, Union
from gigl.common import GcsUri
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.dataset_split import DatasetSplit
from gigl.src.common.types.features import FeatureTypes
from gigl.src.common.types.graph_data import EdgeType, NodeType
# Per-component sub-directory names; the helpers below append these as path
# components beneath a job's temp/perm roots via GcsUri.join.
_CONFIG_POPULATOR: str = "config_populator"
_DATA_PREPROCESSOR_PREFIX: str = "data_preprocess"
_SPLIT_GENERATOR_PREFIX: str = "split_generator"
_SUBGRAPH_SAMPLER_PREFIX: str = "subgraph_sampler"
_TRAINER_PREFIX: str = "trainer"
_INFERENCER_PREFIX: str = "inferencer"
_POST_PROCESSOR_PREFIX: str = "post_processor"
def get_applied_task_temp_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the temp assets bucket for a given GiGL job.

    Layout: ``<temp_assets_bucket_path>/<applied_task_identifier>``, where
    ``temp_assets_bucket_path`` is read from ``get_resource_config()``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The job's root path in the temp assets bucket.
    """
    return GcsUri.join(
        get_resource_config().temp_assets_bucket_path, applied_task_identifier
    )
def get_applied_task_temp_regional_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the temp regional assets for a given GiGL job.

    Layout: ``<temp_assets_regional_bucket_path>/<applied_task_identifier>``,
    where ``temp_assets_regional_bucket_path`` is read from
    ``get_resource_config()``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The job's root path in the temp regional assets bucket.
    """
    return GcsUri.join(
        get_resource_config().temp_assets_regional_bucket_path, applied_task_identifier
    )
def get_applied_task_perm_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the perm assets bucket for a given GiGL job.

    Layout: ``<perm_assets_bucket_path>/<applied_task_identifier>``, where
    ``perm_assets_bucket_path`` is read from ``get_resource_config()``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The job's root path in the perm assets bucket.
    """
    return GcsUri.join(
        get_resource_config().perm_assets_bucket_path, applied_task_identifier
    )
def get_data_preprocessor_assets_temp_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for temporary Data Preprocessor assets for a given GiGL job.

    Layout: ``<regional temp root>/data_preprocess``, where the root is
    ``get_applied_task_temp_regional_gcs_path(applied_task_identifier)``.
    Note that this path is rooted in the *regional* temp assets bucket, not
    in the one returned by ``get_applied_task_temp_gcs_path``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for temporary Data Preprocessor assets.
    """
    return GcsUri.join(
        get_applied_task_temp_regional_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        _DATA_PREPROCESSOR_PREFIX,
    )
def get_data_preprocessor_assets_perm_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the Data Preprocessor perm assets for a given GiGL job.

    Layout: ``<perm root>/data_preprocess``, where the root is
    ``get_applied_task_perm_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the Data Preprocessor perm assets.
    """
    return GcsUri.join(
        get_applied_task_perm_gcs_path(applied_task_identifier=applied_task_identifier),
        _DATA_PREPROCESSOR_PREFIX,
    )
def get_data_preprocessor_staging_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the Data Preprocessor staging directory for a given GiGL job.

    Layout: ``<data preprocessor temp path>/staging``, where the parent is
    ``get_data_preprocessor_assets_temp_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the Data Preprocessor staging directory.
    """
    return GcsUri.join(
        get_data_preprocessor_assets_temp_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        "staging",
    )
def get_split_generator_assets_temp_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the temporary GCS path for Split Generator assets.

    Layout: ``<temp root>/split_generator``, where the root is
    ``get_applied_task_temp_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The temporary GCS path for Split Generator assets.
    """
    return GcsUri.join(
        get_applied_task_temp_gcs_path(applied_task_identifier=applied_task_identifier),
        _SPLIT_GENERATOR_PREFIX,
    )
def get_dataflow_staging_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
    job_name: str,
) -> GcsUri:
    """
    Returns the GCS path for the staging directory used by a Dataflow job.

    Layout: ``<temp root>/<job_name>/staging``, where the root is
    ``get_applied_task_temp_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name
            (not to be confused with ``job_name``).
        job_name (str): The name of the Dataflow job.

    Returns:
        GcsUri: The GCS path for the staging directory of the Dataflow job.
    """
    return GcsUri.join(
        get_applied_task_temp_gcs_path(
            applied_task_identifier=applied_task_identifier,
        ),
        job_name,
        "staging",
    )
def get_dataflow_temp_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
    job_name: str,
) -> GcsUri:
    """
    Returns the GCS path for the "tmp" directory used by a Dataflow job.

    Layout: ``<temp root>/<job_name>/tmp``, where the root is
    ``get_applied_task_temp_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name
            (not to be confused with ``job_name``).
        job_name (str): The name of the Dataflow job.

    Returns:
        GcsUri: The GCS path for the "tmp" directory of the Dataflow job.
    """
    return GcsUri.join(
        get_applied_task_temp_gcs_path(
            applied_task_identifier=applied_task_identifier,
        ),
        job_name,
        "tmp",
    )
def get_split_dataset_output_gcs_file_prefix(
    applied_task_identifier: AppliedTaskIdentifier, dataset_split: DatasetSplit
) -> GcsUri:
    """
    Returns the GCS file prefix for the samples output by Split Generator.

    Layout: ``<split generator temp path>/<dataset_split.value>/samples/``.
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.
        dataset_split (DatasetSplit): The dataset split; its ``.value`` is used
            as a path component.

    Returns:
        GcsUri: The GCS file prefix for the samples output by Split Generator.
    """
    return GcsUri.join(
        get_split_generator_assets_temp_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        dataset_split.value,
        "samples/",
    )
def get_subgraph_sampler_root_dir(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path which Subgraph Sampler uses to store temp assets.

    Layout: ``<temp root>/subgraph_sampler``, where the root is
    ``get_applied_task_temp_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path which Subgraph Sampler uses to store temp assets.
    """
    return GcsUri.join(
        get_applied_task_temp_gcs_path(applied_task_identifier=applied_task_identifier),
        # Use the shared module constant (value "subgraph_sampler") rather than
        # a duplicated literal, matching how the other component roots are built.
        _SUBGRAPH_SAMPLER_PREFIX,
    )
def get_subgraph_sampler_supervised_node_classification_task_dir(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path which Subgraph Sampler uses to store temp assets for supervised node classification.

    Layout: ``<subgraph sampler root>/supervised_node_classification``, where the
    root is ``get_subgraph_sampler_root_dir(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The Subgraph Sampler temp directory for supervised node classification.
    """
    return GcsUri.join(
        get_subgraph_sampler_root_dir(applied_task_identifier=applied_task_identifier),
        "supervised_node_classification",
    )
def get_subgraph_sampler_node_anchor_based_link_prediction_task_dir(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path which Subgraph Sampler uses to store temp assets for node anchor based link prediction.

    Layout: ``<subgraph sampler root>/node_anchor_based_link_prediction``, where
    the root is ``get_subgraph_sampler_root_dir(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The Subgraph Sampler temp directory for node anchor based link prediction.
    """
    return GcsUri.join(
        get_subgraph_sampler_root_dir(applied_task_identifier=applied_task_identifier),
        "node_anchor_based_link_prediction",
    )
def get_subgraph_sampler_supervised_link_based_task_dir(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path which Subgraph Sampler uses to store temp assets for supervised link based tasks.

    Layout: ``<subgraph sampler root>/supervised_link_based``, where the root is
    ``get_subgraph_sampler_root_dir(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The Subgraph Sampler temp directory for supervised link based tasks.
    """
    return GcsUri.join(
        get_subgraph_sampler_root_dir(applied_task_identifier=applied_task_identifier),
        "supervised_link_based",
    )
def get_subgraph_sampler_node_neighborhood_samples_dir(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path which Subgraph Sampler uses to store node neighborhood samples.

    Layout: ``<subgraph sampler root>/node_neighborhood_samples``, where the root
    is ``get_subgraph_sampler_root_dir(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The Subgraph Sampler directory for node neighborhood samples.
    """
    return GcsUri.join(
        get_subgraph_sampler_root_dir(applied_task_identifier=applied_task_identifier),
        "node_neighborhood_samples",
    )
def get_subgraph_sampler_supervised_node_classification_labeled_samples_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS file prefix for labeled samples output by Subgraph Sampler for supervised node classification.

    Layout: ``<supervised node classification task dir>/labeled/samples/``
    (see ``get_subgraph_sampler_supervised_node_classification_task_dir``).
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS file prefix for labeled supervised node classification samples.
    """
    return GcsUri.join(
        get_subgraph_sampler_supervised_node_classification_task_dir(
            applied_task_identifier=applied_task_identifier
        ),
        "labeled",
        "samples/",
    )
def get_subgraph_sampler_supervised_node_classification_unlabeled_samples_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS file prefix for unlabeled samples output by Subgraph Sampler for supervised node classification.

    Layout: ``<supervised node classification task dir>/unlabeled/samples/``
    (see ``get_subgraph_sampler_supervised_node_classification_task_dir``).
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS file prefix for unlabeled supervised node classification samples.
    """
    return GcsUri.join(
        get_subgraph_sampler_supervised_node_classification_task_dir(
            applied_task_identifier=applied_task_identifier
        ),
        "unlabeled",
        "samples/",
    )
def get_subgraph_sampler_supervised_link_based_task_labeled_samples_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS file prefix for labeled samples output by Subgraph Sampler for supervised link based tasks.

    Layout: ``<supervised link based task dir>/labeled/samples/``
    (see ``get_subgraph_sampler_supervised_link_based_task_dir``).
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS file prefix for labeled supervised link based samples.
    """
    return GcsUri.join(
        get_subgraph_sampler_supervised_link_based_task_dir(
            applied_task_identifier=applied_task_identifier
        ),
        "labeled",
        "samples/",
    )
def get_subgraph_sampler_supervised_link_based_task_unlabeled_samples_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS file prefix for unlabeled samples output by Subgraph Sampler for supervised link based tasks.

    Layout: ``<supervised link based task dir>/unlabeled/samples/``
    (see ``get_subgraph_sampler_supervised_link_based_task_dir``).
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS file prefix for unlabeled supervised link based samples.
    """
    return GcsUri.join(
        get_subgraph_sampler_supervised_link_based_task_dir(
            applied_task_identifier=applied_task_identifier
        ),
        "unlabeled",
        "samples/",
    )
def get_subgraph_sampler_node_neighborhood_samples_path_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS file prefix for node neighborhood samples output by Subgraph Sampler.

    Layout: ``<node neighborhood samples dir>/samples/``
    (see ``get_subgraph_sampler_node_neighborhood_samples_dir``).
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS file prefix for node neighborhood samples.
    """
    return GcsUri.join(
        get_subgraph_sampler_node_neighborhood_samples_dir(
            applied_task_identifier=applied_task_identifier
        ),
        "samples/",
    )
def get_subgraph_sampler_node_anchor_based_link_prediction_samples_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS file prefix for samples output by Subgraph Sampler for node anchor based link prediction.

    Layout:
    ``<node anchor based link prediction task dir>/node_anchor_based_link_prediction_samples/samples/``
    (see ``get_subgraph_sampler_node_anchor_based_link_prediction_task_dir``).
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS file prefix for node anchor based link prediction samples.
    """
    return GcsUri.join(
        get_subgraph_sampler_node_anchor_based_link_prediction_task_dir(
            applied_task_identifier=applied_task_identifier
        ),
        "node_anchor_based_link_prediction_samples",
        "samples/",
    )
def get_subgraph_sampler_node_anchor_based_link_prediction_random_negatives_samples_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
    node_type: NodeType,
) -> GcsUri:
    """
    Returns the GCS file prefix for random negative samples output by Subgraph Sampler for node anchor based link prediction.

    Layout:
    ``<node anchor based link prediction task dir>/random_negative_rooted_neighborhood_samples/<node_type>/samples/``
    (see ``get_subgraph_sampler_node_anchor_based_link_prediction_task_dir``).
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.
        node_type (NodeType): The node type whose random negative samples are
            addressed; used directly as a path component.

    Returns:
        GcsUri: The GCS file prefix for random negative samples of ``node_type``.
    """
    return GcsUri.join(
        get_subgraph_sampler_node_anchor_based_link_prediction_task_dir(
            applied_task_identifier=applied_task_identifier
        ),
        "random_negative_rooted_neighborhood_samples",
        node_type,
        "samples/",
    )
def get_split_dataset_main_samples_gcs_file_prefix(
    applied_task_identifier: AppliedTaskIdentifier, dataset_split: DatasetSplit
) -> GcsUri:
    """
    Returns the GCS file prefix for the main samples output by Split Generator.

    Layout: ``<split generator temp path>/<dataset_split.value>/main_samples/samples/``.
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.
        dataset_split (DatasetSplit): The dataset split; its ``.value`` is used
            as a path component.

    Returns:
        GcsUri: The GCS file prefix for the main samples output by Split Generator.
    """
    return GcsUri.join(
        get_split_generator_assets_temp_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        dataset_split.value,
        "main_samples",
        "samples/",
    )
def get_split_dataset_random_negatives_gcs_file_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
    node_type: NodeType,
    dataset_split: DatasetSplit,
) -> GcsUri:
    """
    Returns the GCS file prefix for the random negative samples output by Split Generator.

    Layout:
    ``<split generator temp path>/<dataset_split.value>/random_negatives/<node_type>/neighborhoods/``.
    The returned prefix ends with a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.
        node_type (NodeType): The node type whose random negatives are
            addressed; used directly as a path component.
        dataset_split (DatasetSplit): The dataset split; its ``.value`` is used
            as a path component.

    Returns:
        GcsUri: The GCS file prefix for the random negative samples of
        ``node_type`` in ``dataset_split``.
    """
    return GcsUri.join(
        get_split_generator_assets_temp_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        dataset_split.value,
        "random_negatives",
        node_type,
        "neighborhoods/",
    )
def get_config_populator_assets_perm_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the Config Populator perm assets for a given GiGL job.

    Used to write the frozen GBML config. Layout: ``<perm root>/config_populator``,
    where the root is ``get_applied_task_perm_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the Config Populator perm assets.
    """
    return GcsUri.join(
        get_applied_task_perm_gcs_path(applied_task_identifier=applied_task_identifier),
        _CONFIG_POPULATOR,
    )
def get_frozen_gbml_config_proto_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the frozen GBML config file.

    The config's schema is the GbmlConfig proto
    (see proto/snapchat/research/gbml/gbml_config.proto); the file itself is
    named ``frozen_gbml_config.yaml``. Layout:
    ``<config populator perm path>/frozen_gbml_config.yaml``
    (see ``get_config_populator_assets_perm_gcs_path``).

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the frozen GBML config file.
    """
    return GcsUri.join(
        get_config_populator_assets_perm_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        "frozen_gbml_config.yaml",
    )
def get_trainer_asset_dir_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for perm assets written by the Trainer (e.g. trained models, eval metrics, etc.).

    Layout: ``<perm root>/trainer``, where the root is
    ``get_applied_task_perm_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for perm assets written by the Trainer.
    """
    return GcsUri.join(
        get_applied_task_perm_gcs_path(applied_task_identifier=applied_task_identifier),
        _TRAINER_PREFIX,
    )
def get_trained_models_dir_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the trained models directory.

    Layout: ``<trainer asset dir>/models`` (see ``get_trainer_asset_dir_gcs_path``).

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the trained models directory.
    """
    return GcsUri.join(
        get_trainer_asset_dir_gcs_path(applied_task_identifier=applied_task_identifier),
        "models",
    )
def get_trained_model_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the trained model output by the Trainer (model.pt).

    Layout: ``<trained models dir>/model.pt`` (see ``get_trained_models_dir_gcs_path``).

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the trained model output by the Trainer.
    """
    return GcsUri.join(
        get_trained_models_dir_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        "model.pt",
    )
def get_trained_scripted_model_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the scripted model output by the Trainer (scripted_model.pt).

    Layout: ``<trained models dir>/scripted_model.pt``
    (see ``get_trained_models_dir_gcs_path``).

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the scripted model output by the Trainer.
    """
    return GcsUri.join(
        get_trained_models_dir_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        "scripted_model.pt",
    )
def get_trained_model_eval_metrics_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the eval metrics output by the Trainer (trainer_eval_metrics.json).

    Layout: ``<trained models dir>/trainer_eval_metrics.json``
    (see ``get_trained_models_dir_gcs_path``).

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the eval metrics output by the Trainer.
    """
    return GcsUri.join(
        get_trained_models_dir_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        "trainer_eval_metrics.json",
    )
def get_tensorboard_logs_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path that is used to store TensorBoard logs.

    Layout: ``<trainer asset dir>/tensorboard_logs/``
    (see ``get_trainer_asset_dir_gcs_path``). The returned path ends with a
    trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path that is used to store TensorBoard logs.
    """
    return GcsUri.join(
        get_trainer_asset_dir_gcs_path(applied_task_identifier=applied_task_identifier),
        "tensorboard_logs/",
    )
def get_inferencer_asset_dir_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for perm assets written by the Inferencer (e.g. embeddings, predictions, etc.).

    Layout: ``<perm root>/inferencer``, where the root is
    ``get_applied_task_perm_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for perm assets written by the Inferencer.
    """
    return GcsUri.join(
        get_applied_task_perm_gcs_path(applied_task_identifier=applied_task_identifier),
        _INFERENCER_PREFIX,
    )
def get_inferencer_embeddings_gcs_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS prefix (directory) for embeddings output by the Inferencer.

    Layout: ``<inferencer asset dir>/embeddings/``
    (see ``get_inferencer_asset_dir_gcs_path``). The returned prefix ends with
    a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS prefix for embeddings output by the Inferencer.
    """
    return GcsUri.join(
        get_inferencer_asset_dir_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        "embeddings/",
    )
def get_inferencer_predictions_gcs_prefix(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS prefix (directory) for predictions output by the Inferencer.

    Layout: ``<inferencer asset dir>/predictions/``
    (see ``get_inferencer_asset_dir_gcs_path``). The returned prefix ends with
    a trailing ``/``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS prefix for predictions output by the Inferencer.
    """
    return GcsUri.join(
        get_inferencer_asset_dir_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        "predictions/",
    )
def get_post_processor_asset_dir_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for perm assets written by the Post Processor (e.g. eval metrics, etc.).

    Layout: ``<perm root>/post_processor``, where the root is
    ``get_applied_task_perm_gcs_path(applied_task_identifier)``.

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for perm assets written by the Post Processor.
    """
    return GcsUri.join(
        get_applied_task_perm_gcs_path(applied_task_identifier=applied_task_identifier),
        _POST_PROCESSOR_PREFIX,
    )
def get_post_processor_metrics_gcs_path(
    applied_task_identifier: AppliedTaskIdentifier,
) -> GcsUri:
    """
    Returns the GCS path for the eval metrics output by the Post Processor (post_processor_metrics.json).

    Layout: ``<post processor asset dir>/post_processor_metrics.json``
    (see ``get_post_processor_asset_dir_gcs_path``).

    Args:
        applied_task_identifier (AppliedTaskIdentifier): The GiGL job name.

    Returns:
        GcsUri: The GCS path for the eval metrics output by the Post Processor.
    """
    return GcsUri.join(
        get_post_processor_asset_dir_gcs_path(
            applied_task_identifier=applied_task_identifier
        ),
        "post_processor_metrics.json",
    )