# Source code for gigl.src.data_preprocessor.lib.transform.utils
from typing import Any, Callable, Iterable, Optional, Tuple, Union
import apache_beam as beam
import pyarrow as pa
import tensorflow_data_validation as tfdv
import tensorflow_transform
import tfx_bsl
from apache_beam.pvalue import PBegin, PCollection, PDone
from tensorflow_metadata.proto.v0 import schema_pb2, statistics_pb2
from tensorflow_transform import beam as tft_beam
from tensorflow_transform.tf_metadata import schema_utils
from tfx_bsl.tfxio.record_based_tfxio import RecordBasedTFXIO
from gigl.common import GcsUri, LocalUri, Uri
from gigl.common.beam.better_tfrecordio import BetterWriteToTFRecord  # type: ignore
from gigl.common.logger import Logger
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.utils.dataflow import init_beam_pipeline_options
from gigl.src.common.utils.time import current_formatted_datetime
from gigl.src.data_preprocessor.lib.ingest.reference import (
    DataReference,
    EdgeDataReference,
    NodeDataReference,
)
from gigl.src.data_preprocessor.lib.transform.tf_value_encoder import TFValueEncoder
from gigl.src.data_preprocessor.lib.transform.transformed_features_info import (
    TransformedFeaturesInfo,
)
from gigl.src.data_preprocessor.lib.types import (
    EdgeDataPreprocessingSpec,
    FeatureSpecDict,
    InstanceDict,
    NodeDataPreprocessingSpec,
    TFTensorDict,
)
class InstanceDictToTFExample(beam.DoFn):
    """
    Uses a feature spec to process a raw instance dict (read from some tabular data) as a TFExample.  These
    instance dict inputs could allow us to read tabular input data from BQ, GCS or anything else. As long as we
    have a way of yielding instance dicts and parsing them with a feature spec, we should be able to
    transform this data into TFRecords during ingestion, which allows for more efficient operations in TFT.
    See https://www.tensorflow.org/tfx/transform/get_started#the_tfxio_format.

    Args:
        feature_spec: Mapping of feature name -> feature spec. Only these features are
            extracted from each instance dict, and each spec's ``dtype`` selects the
            imputation value used when a feature is NULL.
        schema: Schema used to construct the ``ExampleProtoCoder`` that serializes
            each parsed row.
    """

    def __init__(
        self,
        feature_spec: FeatureSpecDict,
        schema: schema_pb2.Schema,
    ):
        super().__init__()
        # Both are read in ``process``; they must be stored on the instance so the
        # DoFn carries them to the workers that execute it.
        self.feature_spec = feature_spec
        self.schema = schema
        # Created lazily in ``process`` (see the comment there for why).
        self._coder: Optional[tensorflow_transform.coders.ExampleProtoCoder] = None

    def process(self, element: InstanceDict) -> Iterable[bytes]:
        """
        Encode a single instance dict into a serialized TFExample.

        Args:
            element: One row of the original tabular input (BQ, GCS, etc.).

        Yields:
            The bytes returned by ``ExampleProtoCoder.encode`` for the parsed row.

        Raises:
            KeyError: If a feature named in ``feature_spec`` is not a key of ``element``.
        """
        # This coder is sensitive to environment (e.g., proto library version), and thus
        # it is recommended to instantiate the coder at pipeline execution time (i.e.,
        # in process function) instead of at pipeline construction time (i.e., in __init__)
        if self._coder is None:
            self._coder = tensorflow_transform.coders.ExampleProtoCoder(self.schema)
        # Only features in the user specified feature_spec are extracted from element.
        # Imputation is applied when feature value is NULL.
        parsed_and_imputed_element = {
            feature_name: (
                element[feature_name]
                # If feature_name does not exist as a column in the original table, a
                # KeyError should raise to warn the user. Therefore, we do not use
                # element.get() here.
                if element[feature_name] is not None
                else TFValueEncoder.get_value_to_impute(dtype=spec.dtype)
            )
            for feature_name, spec in self.feature_spec.items()
        }
        yield self._coder.encode(parsed_and_imputed_element)
class IngestRawFeatures(beam.PTransform):
    """
    Read raw tabular features and convert them into a PCollection of Arrow RecordBatches.

    The transform chains three steps:
      1. ``data_reference.yield_instance_dict_ptransform()`` yields instance dicts.
      2. ``InstanceDictToTFExample`` serializes each instance dict as a TFExample.
      3. ``beam_record_tfxio.BeamSource()`` decodes the TFExamples into RecordBatches.

    Args:
        data_reference: Source of the raw tabular data.
        feature_spec: Feature spec forwarded to ``InstanceDictToTFExample``.
        schema: Schema forwarded to ``InstanceDictToTFExample``.
        beam_record_tfxio: TFXIO whose ``BeamSource`` turns serialized examples
            into RecordBatches.
    """

    # TODO: investigate whether convert to TFXIO is adding overhead instead of speeding things up.
    def __init__(
        self,
        data_reference: DataReference,
        feature_spec: FeatureSpecDict,
        schema: schema_pb2.Schema,
        beam_record_tfxio: RecordBasedTFXIO,
    ):
        super().__init__()
        self.data_reference = data_reference
        self.feature_spec = feature_spec
        self.schema = schema
        self.beam_record_tfxio = beam_record_tfxio

    def expand(self, pbegin: PBegin) -> PCollection[pa.RecordBatch]:
        """
        Apply the ingestion chain to the pipeline root.

        Raises:
            TypeError: If the input is not a ``PBegin`` (i.e. the transform was not
                applied directly to a pipeline).
        """
        if not isinstance(pbegin, PBegin):
            raise TypeError(
                f"Input to {IngestRawFeatures.__name__} transform "
                f"must be a PBegin but found {pbegin})"
            )
        return (
            pbegin
            | "Parse raw tabular features into instance dicts."
            >> self.data_reference.yield_instance_dict_ptransform()
            | "Serialize instance dicts to transformed TFExamples"
            >> beam.ParDo(
                InstanceDictToTFExample(
                    feature_spec=self.feature_spec, schema=self.schema
                )
            )
            | "Transformed TFExamples to RecordBatches with TFXIO"
            >> self.beam_record_tfxio.BeamSource()
        )
class GenerateAndVisualizeStats(beam.PTransform):
    """
    Compute TFDV statistics over RecordBatches and persist them in two forms.

    Writes a single-file HTML Facets report to ``facets_report_uri`` and the raw
    statistics as a TFRecord to ``stats_output_uri``. Returns the statistics
    PCollection so callers can consume it further.

    Args:
        facets_report_uri: Destination of the HTML Facets visualization.
        stats_output_uri: Destination of the TFDV statistics TFRecord.
    """

    def __init__(self, facets_report_uri: GcsUri, stats_output_uri: GcsUri):
        super().__init__()
        self.facets_report_uri = facets_report_uri
        self.stats_output_uri = stats_output_uri

    def expand(
        self, features: PCollection[pa.RecordBatch]
    ) -> PCollection[statistics_pb2.DatasetFeatureStatisticsList]:
        stats = features | "Generate TFDV statistics" >> tfdv.GenerateStatistics()
        # num_shards=1 with an empty shard template yields exactly one file at the URI.
        _ = (
            stats
            | "Generate stats visualization"
            >> beam.Map(tfdv.utils.display_util.get_statistics_html)
            | "Write stats Facets report HTML"
            >> beam.io.WriteToText(
                self.facets_report_uri.uri, num_shards=1, shard_name_template=""
            )
        )
        _ = (
            stats
            | "Write TFDV stats output TFRecord"
            >> tfdv.WriteStatisticsToTFRecord(self.stats_output_uri.uri)
        )
        return stats
class ReadExistingTFTransformFn(beam.PTransform):
    """
    Load a previously written TFTransform ``TransformFn`` from a directory.

    Args:
        tf_transform_directory: Directory (GCS or local) that was written by
            ``tft_beam.WriteTransformFn``.
    """

    def __init__(self, tf_transform_directory: Uri):
        super().__init__()
        # The two f-strings are concatenated implicitly into one message; a comma
        # between them would turn the assertion message into a tuple.
        assert isinstance(tf_transform_directory, (GcsUri, LocalUri)), (
            f"tf_transform_directory must be a {GcsUri.__name__} or {LocalUri.__name__}, "
            f"but found {tf_transform_directory.__class__.__name__}"
        )
        self.tf_transform_directory = tf_transform_directory

    def expand(self, pbegin: PBegin) -> PCollection[Any]:
        """
        Read the TransformFn rooted at the pipeline.

        Raises:
            TypeError: If the input is not a ``PBegin``.
        """
        if not isinstance(pbegin, PBegin):
            raise TypeError(
                f"Input to {ReadExistingTFTransformFn.__name__} transform "
                f"must be a PBegin but found {pbegin})"
            )
        return pbegin | "Read existing TransformFn" >> tft_beam.ReadTransformFn(
            path=self.tf_transform_directory.uri
        )
class AnalyzeAndBuildTFTransformFn(beam.PTransform):
    """
    Fit a fresh TFTransform ``TransformFn`` by running ``tft_beam.AnalyzeDataset``.

    Args:
        tensor_adapter_config: Adapter config describing how the input
            RecordBatches map to tensors; passed to AnalyzeDataset alongside them.
        preprocessing_fn: User preprocessing function that is analyzed.
    """

    def __init__(
        self,
        tensor_adapter_config: tfx_bsl.tfxio.tensor_adapter.TensorAdapterConfig,
        preprocessing_fn: Callable[[TFTensorDict], TFTensorDict],
    ):
        super().__init__()
        self.tensor_adapter_config = tensor_adapter_config
        self.preprocessing_fn = preprocessing_fn

    def expand(self, features: PCollection[pa.RecordBatch]) -> PCollection[Any]:
        # AnalyzeDataset takes the (data, tensor_adapter_config) pair as its input.
        return (
            features,
            self.tensor_adapter_config,
        ) | "Analyze raw features dataset" >> tft_beam.AnalyzeDataset(
            preprocessing_fn=self.preprocessing_fn
        )
class WriteTFSchema(beam.PTransform):
    """
    Write a single TF metadata schema proto to a file with ``beam.io.WriteToText``.

    Args:
        schema: The schema proto to write.
        target_uri: Output file location. Because ``shard_name_template`` is empty,
            Beam writes one file at exactly this path.
        schema_descriptor: Short label (e.g. "raw") used only in step names.
    """

    def __init__(
        self, schema: schema_pb2.Schema, target_uri: GcsUri, schema_descriptor: str
    ):
        super().__init__()
        self.schema = schema
        self.target_uri = target_uri
        self.schema_descriptor = schema_descriptor

    def expand(self, pbegin: PBegin) -> PDone:
        """
        Create a one-element PCollection holding the schema and write it out.

        NOTE(review): WriteToText presumably writes the proto's text-format string
        representation — confirm against readers of this file.

        Raises:
            TypeError: If the input is not a ``PBegin``.
        """
        if not isinstance(pbegin, PBegin):
            raise TypeError(
                f"Input to {WriteTFSchema.__name__} transform "
                f"must be a PBegin but found {pbegin})"
            )
        return (
            pbegin
            | f"Create {self.schema_descriptor} schema PCollection"
            >> beam.Create([self.schema])
            | f"Write out {self.schema_descriptor} schema proto"
            >> beam.io.WriteToText(self.target_uri.uri, shard_name_template="")
        )
# Module-level logger; ``Logger`` is imported at the top of this file.
logger = Logger()


def get_load_data_and_transform_pipeline_component(
    applied_task_identifier: AppliedTaskIdentifier,
    data_reference: DataReference,
    preprocessing_spec: Union[NodeDataPreprocessingSpec, EdgeDataPreprocessingSpec],
    transformed_features_info: TransformedFeaturesInfo,
    num_shards: int,
    custom_worker_image_uri: Optional[str] = None,
) -> beam.Pipeline:
    """
    Generate a Beam pipeline to conduct transformation, given a source feature table and an output path in GCS.

    The pipeline is constructed but not run here. It:
      1. Ingests raw features from ``data_reference`` into RecordBatches.
      2. Writes the raw-feature TF schema.
      3. Reads a pretrained TransformFn if ``preprocessing_spec.pretrained_tft_model_uri``
         is set, otherwise analyzes the raw features to build a new one.
      4. Writes the TransformFn to ``transformed_features_info.transform_directory_path``.
      5. Applies the TransformFn and writes transformed features as TFRecords.

    Args:
        applied_task_identifier: Identifier forwarded to the Beam pipeline options.
        data_reference: Node or edge data reference to read raw features from.
        preprocessing_spec: Node or edge preprocessing spec (feature spec fn,
            preprocessing fn, optional pretrained TFT model URI).
        transformed_features_info: Output/temp locations for this feature type.
        num_shards: Passed to ``BetterWriteToTFRecord`` (see TODO below).
        custom_worker_image_uri: Optional custom worker image for the pipeline.

    Returns:
        The constructed ``beam.Pipeline``; the caller is responsible for running it.

    Raises:
        ValueError: If ``data_reference`` is neither an ``EdgeDataReference`` nor a
            ``NodeDataReference``, or ``preprocessing_spec`` is neither a
            ``NodeDataPreprocessingSpec`` nor an ``EdgeDataPreprocessingSpec``.
    """
    qualifier: str
    if isinstance(data_reference, EdgeDataReference):
        qualifier = f"-{data_reference.edge_type}-{data_reference.edge_usage_type}-"
    elif isinstance(data_reference, NodeDataReference):
        qualifier = f"-{data_reference.node_type}-"
    else:
        raise ValueError(
            f"data_reference must be of type {EdgeDataReference.__name__} or {NodeDataReference.__name__}, found: {type(data_reference)}"
        )
    # NOTE(review): this suffix contains "--" and "=" (e.g. "...-=feature-prep-...");
    # confirm init_beam_pipeline_options sanitizes it into a valid job name.
    job_name_suffix = f"{transformed_features_info.feature_type.value}-{qualifier}=feature-prep-{current_formatted_datetime().lower()}"

    resource_config = get_resource_config()
    if isinstance(preprocessing_spec, NodeDataPreprocessingSpec):
        data_preprocessor_config = (
            resource_config.preprocessor_config.node_preprocessor_config
        )
    elif isinstance(preprocessing_spec, EdgeDataPreprocessingSpec):
        data_preprocessor_config = (
            resource_config.preprocessor_config.edge_preprocessor_config
        )
    else:
        raise ValueError(
            f"Preprocessing spec has to be either {NodeDataPreprocessingSpec.__name__} "
            f"or {EdgeDataPreprocessingSpec.__name__}. Value given: {preprocessing_spec}"
        )
    options = init_beam_pipeline_options(
        applied_task_identifier=applied_task_identifier,
        job_name_suffix=job_name_suffix,
        component=GiGLComponents.DataPreprocessor,
        num_workers=data_preprocessor_config.num_workers,
        max_num_workers=data_preprocessor_config.max_num_workers,
        machine_type=data_preprocessor_config.machine_type,
        disk_size_gb=data_preprocessor_config.disk_size_gb,
        # We disable type checking for this pipeline because it uses PTransforms with multiple
        # PCollection inputs/outputs. This is unsupported and hard to circumvent. See
        # https://lists.apache.org/thread/sok35vj08z8rb5drwoltkh5g06pbq19d and
        # https://lists.apache.org/thread/7cczwfz81lqrt431oh80yf3b0qwosf59. Leaving type check
        # enabled causes issues with WriteTransformedTFRecords.
        pipeline_type_check=False,
        resource_config=resource_config.get_resource_config_uri,
        custom_worker_image_uri=custom_worker_image_uri,
    )

    # pipeline start
    p = beam.Pipeline(options=options)
    with tft_beam.Context(
        temp_dir=transformed_features_info.tft_temp_directory_path.uri,
        use_deep_copy_optimization=False,
    ):
        raw_feature_spec = preprocessing_spec.feature_spec_fn()
        raw_data_schema: schema_pb2.Schema = schema_utils.schema_from_feature_spec(
            raw_feature_spec
        )
        beam_record_tfxio = tfx_bsl.tfxio.tf_example_record.TFExampleBeamRecord(
            physical_format="tfrecord", schema=raw_data_schema
        )
        raw_tensor_adapter_config = beam_record_tfxio.TensorAdapterConfig()

        # Ingest raw features from data reference and parse into TFXIO format for TFT to use.
        raw_features = p | IngestRawFeatures(
            data_reference=data_reference,
            feature_spec=raw_feature_spec,
            schema=raw_data_schema,
            beam_record_tfxio=beam_record_tfxio,
        )

        # Write out the TF schema of the raw features.
        _ = p | WriteTFSchema(
            schema=raw_data_schema,
            target_uri=transformed_features_info.raw_data_schema_file_path,
            schema_descriptor="raw",
        )

        # Run TFDV and generate statistics & Facets report visualization.
        # TODO(nshah-sc): revisit commenting this out in the future as needed.
        # _ = raw_features | GenerateAndVisualizeStats(
        #     facets_report_uri=transformed_features_info.visualized_facets_file_path,
        #     stats_output_uri=transformed_features_info.stats_file_path,
        # )

        # Read previous TransformFn assets from a pretrained path if specified, else build a new asset.
        # The resolved TransformFn is chosen explicitly by branch rather than via `or`,
        # so selection does not depend on the truthiness of Beam objects.
        resolved_transform_fn: Tuple[Any, Any]
        analyzed_transform_fn: Optional[Tuple[Any, Any]] = None
        should_use_existing_transform_fn: bool = (
            preprocessing_spec.pretrained_tft_model_uri is not None
        )
        if should_use_existing_transform_fn:
            logger.info(
                f"Will use pretrained TFTransform asset from {preprocessing_spec.pretrained_tft_model_uri}"
            )
            assert preprocessing_spec.pretrained_tft_model_uri is not None
            resolved_transform_fn = p | ReadExistingTFTransformFn(
                tf_transform_directory=preprocessing_spec.pretrained_tft_model_uri
            )
        else:
            logger.info("Will build fresh TFTransform asset.")
            analyzed_transform_fn = raw_features | AnalyzeAndBuildTFTransformFn(
                tensor_adapter_config=raw_tensor_adapter_config,
                preprocessing_fn=preprocessing_spec.preprocessing_fn,
            )
            resolved_transform_fn = analyzed_transform_fn

        # Write TransformFn and associated transform metadata.
        _ = resolved_transform_fn | "Write TransformFn" >> tft_beam.WriteTransformFn(
            transformed_features_info.transform_directory_path.uri
        )

        # Apply TransformFn over raw features
        transformed_features, transformed_metadata = (
            (raw_features, raw_tensor_adapter_config),
            resolved_transform_fn,
        ) | "Transform raw features dataset" >> tft_beam.TransformDataset(
            output_record_batches=True
        )
        # The transformed_features returned by tft_beam.TransformDataset is a
        # PCollection of Tuple[pa.RecordBatch, dict[str, pa.Array]]. The first
        # one are the transformed features. The second one are the passthrough
        # features, which doesn't apply here since we do not specify passthrough_keys
        # in tft_beam.Context. Hence we drop the second one in the tuple.
        transformed_features = transformed_features | "Extract RecordBatch" >> beam.Map(
            lambda element: element[0]
        )

        # The transformed_metadata returned by tft_beam.TransformDataset can only
        # be relied on for encoding purposes when reusing a pretrained transform_fn,
        # yet it could be inaccurate when using a new transform_fn built by
        # tft_beam.AnalyzeDataset. For the latter case, we do not use transformed_metadata
        # returned by tft_beam.TransformDataset, but use deferred_metadata from
        # transform_fn instead.
        resolved_transformed_metadata: Any
        if should_use_existing_transform_fn:
            resolved_transformed_metadata = transformed_metadata
        else:
            assert analyzed_transform_fn is not None
            resolved_transformed_metadata = beam.pvalue.AsSingleton(
                analyzed_transform_fn[1].deferred_metadata
            )

        _ = transformed_features | "Write tf record files" >> BetterWriteToTFRecord(
            file_path_prefix=transformed_features_info.transformed_features_file_prefix.uri,
            max_bytes_per_shard=int(2e8),  # 200mb,
            transformed_metadata=resolved_transformed_metadata,
            # TODO(mkolodner-sc): Right now, a non-zero value for num_shards overrides the max_bytes_per_shard condition. We need to implement
            # a solution where num_shards specified is just a minimum, causing the max_bytes_per_shard rule taking precedent over the num_shards rule. This will require
            # dynamically determining the number of shards produced by max_bytes_per_shard and setting it to be equal to min_num_shards if the value is less than it.
            num_shards=num_shards,
        )
        return p
