Source code for gigl.src.data_preprocessor.lib.transform.tf_value_encoder

from typing import Any, AnyStr, List, Union

import tensorflow as tf


class TFValueEncoder:
    """
    Namespace of static helpers that encode raw Python values as
    ``tf.train.Feature`` protos, imputing a default when the value is missing.
    """

    @staticmethod
    def get_value_to_impute(dtype: tf.dtypes.DType) -> Union[int, str, float]:
        """
        Returns the default value to use for a missing field.

        :param dtype: TensorFlow dtype of the field being imputed.
        :return: ``0`` for integer and bool dtypes, ``0.0`` for floating
            dtypes, and the string ``"MISSING"`` for every other dtype.
        """
        if dtype.is_integer or dtype.is_bool:
            return 0
        if dtype.is_floating:
            return 0.0
        return "MISSING"

    @staticmethod
    def __bytes_values_to_tf_feature(value: List[AnyStr]) -> tf.train.Feature:
        """
        Returns a bytes_list from a string / byte (or list of such).

        ``str`` elements are UTF-8 encoded; ``bytes`` elements are used as-is.

        :param value: iterable of ``str`` / ``bytes`` elements (or an eager
            tensor, which is converted with ``.numpy()`` before iterating).
        :return: a ``tf.train.Feature`` whose ``bytes_list`` holds the values.
        :raises TypeError: if any element is neither ``str`` nor ``bytes``.
        """
        if isinstance(value, type(tf.constant(0))):
            # BytesList won't unpack a string from an EagerTensor.
            value = value.numpy()

        value_bytes: List[bytes] = []
        for v in value:
            if isinstance(v, str):
                value_bytes.append(v.encode("utf-8"))
            elif isinstance(v, bytes):
                value_bytes.append(v)
            else:
                raise TypeError(f"Got object of type {type(v)} (must be bytes or str)")
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value_bytes))

    @staticmethod
    def __float_values_to_tf_feature(value: List[float]) -> tf.train.Feature:
        """
        Returns a float_list from a float / double (or list of such).

        :param value: list of float values.
        :return: a ``tf.train.Feature`` whose ``float_list`` holds the values.
        """
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    @staticmethod
    def __int_values_to_tf_feature(value: List[int]) -> tf.train.Feature:
        """
        Returns an int64_list from a bool / enum / int / uint (or list of such).

        :param value: list of integer-like values.
        :return: a ``tf.train.Feature`` whose ``int64_list`` holds the values.
        """
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

    @staticmethod
    def encode_value_as_feature(value: Any, dtype: tf.dtypes.DType) -> tf.train.Feature:
        """
        Try to encode a given "raw value" as a tf.train.Feature of the intended
        type. Imputes missing values according to defaults for their dtype.

        :param value: the raw value. ``None`` is replaced by the imputed
            default for ``dtype``; a ``list`` or ``tuple`` is encoded as a
            multi-valued feature; an eager tensor is converted to plain Python
            values first; anything else is treated as a single value.
        :param dtype: TensorFlow dtype that selects the feature kind: integer
            and bool dtypes map to ``int64_list``, floating dtypes to
            ``float_list``, and everything else to ``bytes_list``.
        :return: the encoded ``tf.train.Feature``.
        :raises TypeError: for the ``bytes_list`` case, if an element is
            neither ``str`` nor ``bytes``.
        """
        # prepare value
        if value is None:
            value = TFValueEncoder.get_value_to_impute(dtype=dtype)
        elif isinstance(value, type(tf.constant(0))):
            # Unwrap eager tensors *before* list-wrapping: once the tensor is
            # hidden inside a list, the encoders below can no longer detect
            # it and the bytes encoder rejects it with a TypeError.
            value = value.numpy()
            # numpy arrays / numpy scalars expose tolist(); a rank-0 string
            # tensor yields plain ``bytes``, which has no tolist().
            if hasattr(value, "tolist"):
                value = value.tolist()

        if isinstance(value, tuple):
            # A tuple is a sequence of values, not a single scalar value.
            value = list(value)
        elif not isinstance(value, list):
            value = [value]

        # encode value
        if dtype.is_integer or dtype.is_bool:
            return TFValueEncoder.__int_values_to_tf_feature(value=value)
        if dtype.is_floating:
            return TFValueEncoder.__float_values_to_tf_feature(value=value)
        return TFValueEncoder.__bytes_values_to_tf_feature(value=value)