from typing import Optional, Tuple
import apache_beam as beam
import tensorflow as tf
from apache_beam.pvalue import PCollection
from gigl.common import GcsUri, Uri
from gigl.common.logger import Logger
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.graph_data import NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.dataflow import init_beam_pipeline_options
from gigl.src.data_preprocessor.lib.enumerate.queries import (
    DEFAULT_ENUMERATED_NODE_ID_FIELD,
    DEFAULT_ORIGINAL_NODE_ID_FIELD,
)
from gigl.src.inference.v1.lib.base_inference_blueprint import (
    EMBEDDING_TAGGED_OUTPUT_KEY,
    PREDICTION_TAGGED_OUTPUT_KEY,
    BaseInferenceBlueprint,
)
from gigl.src.inference.v1.lib.inference_output_schema import (
    DEFAULT_EMBEDDING_FIELD,
    DEFAULT_NODE_ID_FIELD,
    DEFAULT_PREDICTION_FIELD,
)
from gigl.src.inference.v1.lib.transforms.batch_generator import BatchProcessorDoFn
from snapchat.research.gbml.gigl_resource_config_pb2 import DataflowResourceConfig
logger = Logger()

# TODO(svij-sc) adopt dynamic batching
# Number of input records per inference batch, used when the GBML
# inferencer config's inference_batch_size is falsy (e.g. 0 / unset).
DEFAULT_BATCH_SIZE = 3000
class UnenumerateAssets(beam.PTransform):
    """
    Beam transform that replaces enumerated node ids in inference outputs with
    the original node ids, using a join against an enumeration mapping.

    Args:
        tagged_output_key (str): Inferer output being un-enumerated. Must be
            PREDICTION_TAGGED_OUTPUT_KEY (values read from DEFAULT_PREDICTION_FIELD)
            or EMBEDDING_TAGGED_OUTPUT_KEY (values read from DEFAULT_EMBEDDING_FIELD).

    Raises:
        NotImplementedError: If tagged_output_key is any other value.
    """

    def __init__(self, tagged_output_key: str):
        # Initialize base PTransform state before setting our own attributes.
        super().__init__()
        if tagged_output_key == PREDICTION_TAGGED_OUTPUT_KEY:
            self.field = DEFAULT_PREDICTION_FIELD
        elif tagged_output_key == EMBEDDING_TAGGED_OUTPUT_KEY:
            self.field = DEFAULT_EMBEDDING_FIELD
        else:
            raise NotImplementedError(
                "Only embedding or prediction outputs are supported"
            )

    def expand(self, pcolls: Tuple[PCollection, PCollection]) -> PCollection:
        """
        Performs unenumeration on two PCollections through a join between the two collections.

        Args:
            pcolls: A ``(output, mapping)`` pair.
                ``output`` holds rows with DEFAULT_NODE_ID_FIELD (the enumerated id)
                and the field selected in ``__init__`` (DEFAULT_PREDICTION_FIELD or
                DEFAULT_EMBEDDING_FIELD).
                ``mapping`` holds ``(enumerated_node_id, original_node_id)`` tuples,
                i.e. it must already be keyed by the enumerated id so it can take
                part in ``CoGroupByKey``.

        Returns:
            PCollection of dicts ``{DEFAULT_NODE_ID_FIELD: original_id, <field>: value}``.
            Enumerated ids that appear only in ``mapping`` (no inference output)
            are dropped.

        Raises:
            ValueError: When the pipeline runs, if an output row's enumerated id
                has no entry in ``mapping``.
        """
        output, mapping = pcolls
        # Bind locally so the closures below capture only the field name rather
        # than pickling the whole transform.
        field = self.field

        enumerated_assets = output | "Format outputs" >> beam.Map(
            lambda row: (row[DEFAULT_NODE_ID_FIELD], row[field])
        )

        def _to_unenumerated_rows(kv):
            """Emit one original-id row for a joined key that has an output."""
            enumerated_id, grouped = kv
            assets = list(grouped["enumerated_assets"])
            original_ids = list(grouped["mapping"])
            if not assets:
                # Key exists only in the mapping (no inference output for this
                # id); skip it instead of failing on an empty index.
                return
            if not original_ids:
                raise ValueError(
                    f"No original node id found in mapping for enumerated node id {enumerated_id}"
                )
            # Take the first value from each side of the join.
            yield {
                DEFAULT_NODE_ID_FIELD: original_ids[0],
                field: assets[0],
            }

        unenumerated_assets = (
            {"enumerated_assets": enumerated_assets, "mapping": mapping}
            # CoGroupByKey joins on the first element of each tuple: the
            # enumerated node id from both the outputs and the mapping.
            | "Perform join" >> beam.CoGroupByKey()
            | "Extract prediction and Format" >> beam.FlatMap(_to_unenumerated_rows)
        )
        return unenumerated_assets
 
def get_inferencer_pipeline_component_for_single_node_type(
    gbml_config_pb_wrapper: GbmlConfigPbWrapper,
    inference_blueprint: BaseInferenceBlueprint,
    applied_task_identifier: AppliedTaskIdentifier,
    custom_worker_image_uri: Optional[str],
    node_type: NodeType,
    uri_prefix_list: list[Uri],
    temp_predictions_gcs_path: Optional[GcsUri],
    temp_embeddings_gcs_path: Optional[GcsUri],
) -> beam.Pipeline:
    """
    Gets the beam pipeline for running the inference dataflow job for one node type.

    The pipeline has the following stages:
    - Read the ``*.tfrecord`` files under every prefix in ``uri_prefix_list``.
      Prefixes with no files are skipped with a warning.
    - Batch the records and run the blueprint's inferer.
    - If the node type has an enumerated-node-id BigQuery table, map node ids
      back to their original values.
    - Write predictions and/or embeddings as ``.json`` text files to the given
      temp GCS paths.

    The pipeline is only constructed here; it is not run.

    Args:
        gbml_config_pb_wrapper (GbmlConfigPbWrapper): GBML config wrapper for this inference run
        inference_blueprint (BaseInferenceBlueprint): Blueprint for running and saving inference for GBML pipelines
        applied_task_identifier (AppliedTaskIdentifier): Identifier for the GiGL job
        custom_worker_image_uri (Optional[str]): Uri to custom worker image
        node_type (NodeType): Node type being inferred
        uri_prefix_list (list[Uri]): List of prefixes for running inference for given node type
        temp_predictions_gcs_path (Optional[GcsUri]): Gcs uri for writing temp predictions; predictions are not written if None
        temp_embeddings_gcs_path (Optional[GcsUri]): Gcs uri for writing temp embeddings; embeddings are not written if None
    Returns:
        pipeline (beam.Pipeline): Dataflow pipeline for running inference
    Raises:
        AssertionError: If the inferencer resource config is not a DataflowResourceConfig.
    """
    # Launching one beam pipeline per node type
    resource_config = get_resource_config()
    inferencer_config = resource_config.inferencer_config
    assert isinstance(
        inferencer_config, DataflowResourceConfig
    ), f"Only Dataflow is supported for v1 inference, got: {type(inferencer_config)}"
    condensed_node_type_to_preprocessed_metadata = (
        gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata
    )
    # A falsy inference_batch_size (e.g. 0 / unset) falls back to the default.
    batch_size = (
        gbml_config_pb_wrapper.inferencer_config.inference_batch_size
        or DEFAULT_BATCH_SIZE
    )
    options = init_beam_pipeline_options(
        applied_task_identifier=applied_task_identifier,
        job_name_suffix=f"{node_type}_inference",
        component=GiGLComponents.Inferencer,
        num_workers=inferencer_config.num_workers,
        max_num_workers=inferencer_config.max_num_workers,
        machine_type=inferencer_config.machine_type,
        disk_size_gb=inferencer_config.disk_size_gb,
        resource_config=resource_config.get_resource_config_uri,
        custom_worker_image_uri=custom_worker_image_uri,
    )
    condensed_node_type = gbml_config_pb_wrapper.graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[
        NodeType(node_type)
    ]
    mapping_table = condensed_node_type_to_preprocessed_metadata[
        condensed_node_type
    ].enumerated_node_ids_bq_table
    # An empty table name means no mapping table exists; outputs are then
    # written with their node ids unchanged.
    should_run_unenumeration = bool(mapping_table)

    pipeline = beam.Pipeline(options=options)
    record_pcolls = []
    for uri_prefix in uri_prefix_list:
        tfrecord_glob_str = f"{uri_prefix.uri}*.tfrecord"
        # Skip prefixes with no matching files rather than adding a read for them.
        if not tf.io.gfile.glob(tfrecord_glob_str):
            logger.warning(
                f"Found no TFRecord files at {uri_prefix.uri} for node type {node_type}"
            )
            continue
        pcol = (
            pipeline
            | f"Read TFRecords from uri prefix {uri_prefix}"
            >> beam.io.ReadFromTFRecord(
                file_pattern=tfrecord_glob_str,
                coder=inference_blueprint.get_tf_record_coder(),
            )
        )
        record_pcolls.append(pcol)

    outputs = (
        record_pcolls
        # Passing the pipeline lets Flatten produce an empty PCollection when
        # every prefix was skipped and record_pcolls is empty; without it Beam
        # cannot determine the owning pipeline and raises a ValueError.
        | "Flatten read pcollections" >> beam.Flatten(pipeline=pipeline)
        | "Batch Elements"
        >> beam.BatchElements(
            min_batch_size=batch_size,
            max_batch_size=batch_size,
            target_batch_duration_secs_including_fixed_cost=1,
        )
        | "Generate Batches"
        >> beam.ParDo(
            BatchProcessorDoFn(
                batch_generator_fn=inference_blueprint.get_batch_generator_fn(),
            )
        )
        | "Inference"
        >> beam.ParDo(inference_blueprint.get_inferer()).with_outputs(
            PREDICTION_TAGGED_OUTPUT_KEY, EMBEDDING_TAGGED_OUTPUT_KEY
        )
    )

    predictions = outputs[PREDICTION_TAGGED_OUTPUT_KEY]
    embeddings = outputs[EMBEDDING_TAGGED_OUTPUT_KEY]
    if should_run_unenumeration:
        # One mapping read is shared by both unenumeration joins.
        mapping = (
            pipeline
            | "Read mapping"
            >> beam.io.gcp.bigquery.ReadFromBigQuery(table=mapping_table)
            | "Map mapping field for node type"
            >> beam.Map(
                lambda row: (
                    row[DEFAULT_ENUMERATED_NODE_ID_FIELD],
                    row[DEFAULT_ORIGINAL_NODE_ID_FIELD],
                )
            )
        )
        if temp_predictions_gcs_path is not None:
            predictions = (
                predictions,
                mapping,
            ) | "Unenumerate Predictions" >> UnenumerateAssets(
                tagged_output_key=PREDICTION_TAGGED_OUTPUT_KEY
            )
        if temp_embeddings_gcs_path is not None:
            embeddings = (
                embeddings,
                mapping,
            ) | "Unenumerate Embeddings" >> UnenumerateAssets(
                tagged_output_key=EMBEDDING_TAGGED_OUTPUT_KEY
            )
    else:
        logger.info(
            f"Skipping un-enumeration for node type {node_type} since no mapping table exists"
        )

    if temp_predictions_gcs_path is not None:
        logger.info(
            f"Writing node type {node_type} temp predictions to gcs path {temp_predictions_gcs_path.uri}"
        )
        (
            predictions
            | "Write temp predictions to gcs"
            >> beam.io.WriteToText(
                file_path_prefix=temp_predictions_gcs_path.uri,
                file_name_suffix=".json",
            )
        )
    if temp_embeddings_gcs_path is not None:
        logger.info(
            f"Writing node type {node_type} temp embeddings to gcs path {temp_embeddings_gcs_path.uri}"
        )
        (
            embeddings
            | "Write temp embeddings to gcs"
            >> beam.io.WriteToText(
                file_path_prefix=temp_embeddings_gcs_path.uri,
                file_name_suffix=".json",
            )
        )
    return pipeline