import tempfile
from dataclasses import dataclass
from typing import Optional
import tensorflow as tf
import torch
from tensorflow_transform.tf_metadata import schema_utils
from gigl.common import GcsUri, LocalUri, UriFactory
from gigl.common.logger import Logger
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.src.common.types.graph_data import (
    CondensedEdgeType,
    CondensedNodeType,
    EdgeType,
    EdgeUsageType,
    NodeType,
)
from gigl.src.common.utils.file_loader import FileLoader
from gigl.src.data_preprocessor.lib.transform.tf_value_encoder import TFValueEncoder
from gigl.src.data_preprocessor.lib.types import FeatureSpecDict, InstanceDict
from gigl.src.mocking.lib.constants import (
    get_example_task_edge_features_gcs_dir,
    get_example_task_edge_features_schema_gcs_path,
    get_example_task_node_features_gcs_dir,
    get_example_task_node_features_schema_gcs_path,
)
from gigl.src.mocking.lib.feature_handling import get_feature_field_name
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
from snapchat.research.gbml import gbml_config_pb2, preprocessed_metadata_pb2
@dataclass
class _PreprocessMetadata:
    features_uri: GcsUri
    schema_uri: GcsUri
    feature_cols: list[str]
@dataclass
class _NodePreprocessMetadata(_PreprocessMetadata):
    id_col: str
    label_col: Optional[str] = None
@dataclass
class _EdgePreprocessMetadata(_PreprocessMetadata):
    src_id_col: str
    dst_id_col: str
class _InstanceDictToTFExample:
    """
    Uses a feature spec to process a raw instance dict (read from some tabular data) as a TFExample.
    """
    def __init__(self, feature_spec: FeatureSpecDict):
        self.feature_spec = feature_spec
    def process(self, element: InstanceDict) -> bytes:
        # Each row is a single instance dict from the original tabular input (BQ, GCS, etc.)
        example = dict()
        for key in self.feature_spec.keys():
            # prepare each value associated with a key that appears in the feature_spec.
            # only the instance dict keys the user specifies wanting in the feature_spec will pass through here
            value = element[key]
            if value is None:
                logger.debug(f"Found key {key} with missing value in sample {element}")
            example[key] = TFValueEncoder.encode_value_as_feature(
                value=value, dtype=self.feature_spec[key].dtype
            )
        example_proto = tf.train.Example(features=tf.train.Features(feature=example))
        serialized_proto = example_proto.SerializeToString()
        return serialized_proto
def _generate_preprocessed_node_tfrecord_data(
    data: MockedDatasetInfo,
    version: str,
    node_type: NodeType,
    num_node_features: int,
    node_features: torch.Tensor,
    node_labels: Optional[torch.Tensor],
) -> _NodePreprocessMetadata:
    feature_names: list[str] = [
        get_feature_field_name(n=i) for i in range(num_node_features)
    ]
    feature_spec_dict = {
        data.node_id_column_name: tf.io.FixedLenFeature(shape=[], dtype=tf.int64)
    }
    feature_spec_dict.update(
        {
            col: tf.io.FixedLenFeature(shape=[], dtype=tf.float32)
            for col in feature_names
        }
    )
    if node_labels is not None:
        feature_spec_dict.update(
            {
                data.node_label_column_name: tf.io.FixedLenFeature(
                    shape=[], dtype=tf.int64
                )
            }
        )
    id2tfe_encoder = _InstanceDictToTFExample(feature_spec=feature_spec_dict)
    tfrecords = []
    instance_dict_feats: dict[str, torch.Tensor]
    for node_id, node_feature_values in enumerate(node_features):
        instance_dict_feats = {data.node_id_column_name: torch.LongTensor([node_id])}
        instance_dict_feats.update(
            {
                feat_name: feat_value
                for feat_name, feat_value in zip(feature_names, node_feature_values)
            }
        )
        if node_labels is not None:
            instance_dict_feats.update(
                {data.node_label_column_name: node_labels[node_id]}
            )
        tfrecord_bytes = id2tfe_encoder.process(element=instance_dict_feats)
        tfrecords.append(tfrecord_bytes)
    # Write features to GCS.
    features_path = get_example_task_node_features_gcs_dir(
        task_name=data.name, version=version, node_type=node_type
    )
    with tf.io.TFRecordWriter(
        GcsUri.join(features_path, "data.tfrecord").uri
    ) as writer:
        for tfrecord in tfrecords:
            writer.write(tfrecord)
    logger.info(
        f"Wrote preprocessed node TFRecords for type {node_type} to prefix {features_path.uri}"
    )
    # Write schema to GCS.
    node_schema_uri = get_example_task_node_features_schema_gcs_path(
        task_name=data.name, version=version, node_type=node_type
    )
    node_schema = schema_utils.schema_from_feature_spec(feature_spec=feature_spec_dict)
    file_loader = FileLoader()
    temp_file_handle = tempfile.NamedTemporaryFile()
    with open(temp_file_handle.name, "w") as f:
        f.write(repr(node_schema))
    file_loader.load_file(
        file_uri_src=LocalUri(temp_file_handle.name),
        file_uri_dst=node_schema_uri,
    )
    logger.info(
        f"Wrote preprocessed node TFRecords schema for type {node_type} to prefix {node_schema_uri.uri}"
    )
    return _NodePreprocessMetadata(
        features_uri=features_path,
        schema_uri=node_schema_uri,
        feature_cols=feature_names,
        id_col=data.node_id_column_name,
        label_col=data.node_label_column_name if node_labels is not None else None,
    )
def _generate_preprocessed_edge_tfrecord_data(
    data: MockedDatasetInfo,
    version: str,
    edge_type: EdgeType,
    edge_index: torch.Tensor,
    num_edge_features: int,
    edge_features: Optional[torch.Tensor],
    edge_usage_type: EdgeUsageType,
) -> _EdgePreprocessMetadata:
    feature_names: list[str] = [
        get_feature_field_name(n=i) for i in range(num_edge_features)
    ]
    feature_spec_dict = {
        data.edge_src_column_name: tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
        data.edge_dst_column_name: tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
    }
    feature_spec_dict.update(
        {
            col: tf.io.FixedLenFeature(shape=[], dtype=tf.float32)
            for col in feature_names
        }
    )
    id2tfe_encoder = _InstanceDictToTFExample(feature_spec=feature_spec_dict)
    tfrecords = []
    for edge_id, (src_id, dst_id) in enumerate(zip(edge_index[0, :], edge_index[1, :])):
        instance_dict_feats = {
            data.edge_src_column_name: src_id,
            data.edge_dst_column_name: dst_id,
        }
        if edge_features is not None:
            edge_feature_values = edge_features[edge_id, :]
            instance_dict_feats.update(
                {
                    feat_name: feat_value
                    for feat_name, feat_value in zip(feature_names, edge_feature_values)
                }
            )
        tfrecord_bytes = id2tfe_encoder.process(element=instance_dict_feats)
        tfrecords.append(tfrecord_bytes)
    # Write features to GCS.
    features_path = get_example_task_edge_features_gcs_dir(
        task_name=data.name,
        version=version,
        edge_type=edge_type,
        edge_usage_type=edge_usage_type,
    )
    with tf.io.TFRecordWriter(
        GcsUri.join(features_path, "data.tfrecord").uri
    ) as writer:
        for tfrecord in tfrecords:
            writer.write(tfrecord)
    logger.info(
        f"Wrote preprocessed edge TFRecords for type {edge_type} to prefix {features_path.uri}"
    )
    # Write schema to GCS.
    edge_schema_uri = get_example_task_edge_features_schema_gcs_path(
        task_name=data.name,
        version=version,
        edge_type=edge_type,
        edge_usage_type=edge_usage_type,
    )
    edge_schema = schema_utils.schema_from_feature_spec(feature_spec=feature_spec_dict)
    file_loader = FileLoader()
    temp_file_handle = tempfile.NamedTemporaryFile()
    with open(temp_file_handle.name, "w") as f:
        f.write(repr(edge_schema))
    file_loader.load_file(
        file_uri_src=LocalUri(temp_file_handle.name),
        file_uri_dst=edge_schema_uri,
    )
    logger.info(
        f"Wrote preprocessed edge TFRecords schema for type {edge_type} to {edge_schema_uri.uri}"
    )
    return _EdgePreprocessMetadata(
        features_uri=features_path,
        schema_uri=edge_schema_uri,
        feature_cols=feature_names,
        src_id_col=data.edge_src_column_name,
        dst_id_col=data.edge_dst_column_name,
    )
[docs]
def generate_preprocessed_tfrecord_data(
    mocked_dataset_info: MockedDatasetInfo,
    version: str,
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
):
    graph_metadata_pb_wrapper = mocked_dataset_info.graph_metadata_pb_wrapper
    num_features_by_node_type = mocked_dataset_info.num_node_features
    node_features_by_node_type = mocked_dataset_info.node_feats
    node_labels_by_node_type = mocked_dataset_info.node_labels
    node_types = mocked_dataset_info.node_types
    condensed_node_type_to_preprocessed_metadata: dict[
        CondensedNodeType,
        preprocessed_metadata_pb2.PreprocessedMetadata.NodeMetadataOutput,
    ] = dict()
    for node_type in node_types:
        condensed_node_type = (
            graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[node_type]
        )
        num_features = num_features_by_node_type[node_type]
        node_features = node_features_by_node_type[node_type]
        if not num_features:
            continue
        node_labels = (
            node_labels_by_node_type[node_type]
            if node_labels_by_node_type is not None
            else None
        )
        node_preprocess_metadata = _generate_preprocessed_node_tfrecord_data(
            data=mocked_dataset_info,
            version=version,
            node_type=node_type,
            num_node_features=num_features,
            node_features=node_features,
            node_labels=node_labels,
        )
        condensed_node_type_to_preprocessed_metadata[
            condensed_node_type
        ] = preprocessed_metadata_pb2.PreprocessedMetadata.NodeMetadataOutput(
            node_id_key=node_preprocess_metadata.id_col,
            feature_keys=node_preprocess_metadata.feature_cols,
            label_keys=[node_preprocess_metadata.label_col] if node_preprocess_metadata.label_col is not None else None,  # type: ignore
            tfrecord_uri_prefix=node_preprocess_metadata.features_uri.uri,
            schema_uri=node_preprocess_metadata.schema_uri.uri,
            feature_dim=num_features,
        )
    num_features_by_edge_type = mocked_dataset_info.num_edge_features
    edge_features_by_edge_type = mocked_dataset_info.edge_feats
    edge_index_by_edge_type = mocked_dataset_info.edge_index
    edge_types = mocked_dataset_info.edge_types
    condensed_edge_type_to_preprocessed_metadata: dict[
        CondensedEdgeType,
        preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataOutput,
    ] = dict()
    for edge_type in edge_types:
        condensed_edge_type = (
            graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[edge_type]
        )
        num_features = num_features_by_edge_type[edge_type]
        edge_features = (
            edge_features_by_edge_type[edge_type]
            if edge_features_by_edge_type is not None
            else None
        )
        edge_index = edge_index_by_edge_type[edge_type]
        edge_preprocess_metadata = _generate_preprocessed_edge_tfrecord_data(
            data=mocked_dataset_info,
            version=version,
            edge_type=edge_type,
            edge_index=edge_index,
            num_edge_features=num_features,
            edge_features=edge_features,
            edge_usage_type=EdgeUsageType.MAIN,
        )
        main_edge_metadata_info_pb = (
            preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataInfo(
                feature_keys=edge_preprocess_metadata.feature_cols,
                tfrecord_uri_prefix=edge_preprocess_metadata.features_uri.uri,
                schema_uri=edge_preprocess_metadata.schema_uri.uri,
                feature_dim=num_features,
            )
        )
        if (
            mocked_dataset_info.user_defined_edge_index
            and edge_type in mocked_dataset_info.user_defined_edge_index
        ):
            edge_preprocess_metadata_pb_dict = {}
            for (
                user_def_label,
                user_def_edge_index,
            ) in mocked_dataset_info.user_defined_edge_index[edge_type].items():
                num_edge_feats = mocked_dataset_info.num_user_def_edge_features[
                    edge_type
                ][user_def_label]
                user_defined_edge_feats = (
                    mocked_dataset_info.user_defined_edge_feats[edge_type][
                        user_def_label
                    ]
                    if mocked_dataset_info.user_defined_edge_feats
                    and edge_type in mocked_dataset_info.user_defined_edge_feats
                    else None
                )
                user_def_edge_preprocess_metadata = (
                    _generate_preprocessed_edge_tfrecord_data(
                        data=mocked_dataset_info,
                        version=version,
                        edge_type=edge_type,
                        edge_index=user_def_edge_index,
                        num_edge_features=num_edge_feats,
                        edge_features=user_defined_edge_feats,
                        edge_usage_type=user_def_label,
                    )
                )
                user_def_edge_metadata_info_pb = preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataInfo(
                    feature_keys=user_def_edge_preprocess_metadata.feature_cols,
                    tfrecord_uri_prefix=user_def_edge_preprocess_metadata.features_uri.uri,
                    schema_uri=user_def_edge_preprocess_metadata.schema_uri.uri,
                    feature_dim=num_edge_feats,
                )
                edge_preprocess_metadata_pb_dict[
                    user_def_label
                ] = user_def_edge_metadata_info_pb
            condensed_edge_type_to_preprocessed_metadata[
                condensed_edge_type
            ] = preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataOutput(
                src_node_id_key=edge_preprocess_metadata.src_id_col,
                dst_node_id_key=edge_preprocess_metadata.dst_id_col,
                main_edge_info=main_edge_metadata_info_pb,
                positive_edge_info=edge_preprocess_metadata_pb_dict.get(
                    EdgeUsageType.POSITIVE, None
                ),
                negative_edge_info=edge_preprocess_metadata_pb_dict.get(
                    EdgeUsageType.NEGATIVE, None
                ),
            )
        else:
            condensed_edge_type_to_preprocessed_metadata[
                condensed_edge_type
            ] = preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataOutput(
                src_node_id_key=edge_preprocess_metadata.src_id_col,
                dst_node_id_key=edge_preprocess_metadata.dst_id_col,
                main_edge_info=main_edge_metadata_info_pb,
            )
    # Assemble Preprocessed Metadata pb and write out.
    preprocessed_metadata_pb = preprocessed_metadata_pb2.PreprocessedMetadata()
    for (
        condensed_node_type,
        node_metadata_output,
    ) in condensed_node_type_to_preprocessed_metadata.items():
        preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[
            condensed_node_type
        ].CopyFrom(node_metadata_output)
    for (
        condensed_edge_type,
        edge_metadata_output,
    ) in condensed_edge_type_to_preprocessed_metadata.items():
        preprocessed_metadata_pb.condensed_edge_type_to_preprocessed_metadata[
            condensed_edge_type
        ].CopyFrom(edge_metadata_output)
    preprocessed_metadata_uri = UriFactory.create_uri(
        gbml_config_pb.shared_config.preprocessed_metadata_uri
    )
    proto_utils = ProtoUtils()
    proto_utils.write_proto_to_yaml(
        proto=preprocessed_metadata_pb, uri=preprocessed_metadata_uri
    )
    logger.info(
        f"Wrote preprocessed metadata proto to {preprocessed_metadata_uri.uri}."
    )