Source code for gigl.src.mocking.lib.mock_input_for_subgraph_sampler

import tempfile
from dataclasses import dataclass
from typing import Dict, List, Optional

import tensorflow as tf
import torch
from tensorflow_transform.tf_metadata import schema_utils

from gigl.common import GcsUri, LocalUri, UriFactory
from gigl.common.logger import Logger
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.src.common.types.graph_data import (
    CondensedEdgeType,
    CondensedNodeType,
    EdgeType,
    EdgeUsageType,
    NodeType,
)
from gigl.src.common.utils.file_loader import FileLoader
from gigl.src.data_preprocessor.lib.transform.tf_value_encoder import TFValueEncoder
from gigl.src.data_preprocessor.lib.types import FeatureSpecDict, InstanceDict
from gigl.src.mocking.lib.constants import (
    get_example_task_edge_features_gcs_dir,
    get_example_task_edge_features_schema_gcs_path,
    get_example_task_node_features_gcs_dir,
    get_example_task_node_features_schema_gcs_path,
)
from gigl.src.mocking.lib.feature_handling import get_feature_field_name
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
from snapchat.research.gbml import gbml_config_pb2, preprocessed_metadata_pb2

logger = Logger()


@dataclass
class _PreprocessMetadata:
    features_uri: GcsUri
    schema_uri: GcsUri
    feature_cols: List[str]


@dataclass
class _NodePreprocessMetadata(_PreprocessMetadata):
    id_col: str
    label_col: Optional[str] = None


@dataclass
class _EdgePreprocessMetadata(_PreprocessMetadata):
    src_id_col: str
    dst_id_col: str


class _InstanceDictToTFExample:
    """
    Uses a feature spec to encode a raw instance dict (read from some tabular data)
    into a serialized TFExample.
    """

    def __init__(self, feature_spec: FeatureSpecDict):
        self.feature_spec = feature_spec

    def process(self, element: InstanceDict) -> bytes:
        # Each row is a single instance dict from the original tabular input (BQ, GCS, etc.)
        example = dict()
        for key in self.feature_spec.keys():
            # Prepare each value associated with a key that appears in the feature_spec.
            # Only the instance-dict keys the user requests in the feature_spec pass through here.
            value = element[key]
            if value is None:
                logger.debug(f"Found key {key} with missing value in sample {element}")
            example[key] = TFValueEncoder.encode_value_as_feature(
                value=value, dtype=self.feature_spec[key].dtype
            )
        example_proto = tf.train.Example(features=tf.train.Features(feature=example))
        serialized_proto = example_proto.SerializeToString()
        return serialized_proto


def _generate_preprocessed_node_tfrecord_data(
    data: MockedDatasetInfo,
    version: str,
    node_type: NodeType,
    num_node_features: int,
    node_features: torch.Tensor,
    node_labels: Optional[torch.Tensor],
) -> _NodePreprocessMetadata:
    feature_names: List[str] = [
        get_feature_field_name(n=i) for i in range(num_node_features)
    ]
    feature_spec_dict = {
        data.node_id_column_name: tf.io.FixedLenFeature(shape=[], dtype=tf.int64)
    }
    feature_spec_dict.update(
        {
            col: tf.io.FixedLenFeature(shape=[], dtype=tf.float32)
            for col in feature_names
        }
    )
    if node_labels is not None:
        feature_spec_dict.update(
            {
                data.node_label_column_name: tf.io.FixedLenFeature(
                    shape=[], dtype=tf.int64
                )
            }
        )
    id2tfe_encoder = _InstanceDictToTFExample(feature_spec=feature_spec_dict)
    tfrecords = []
    instance_dict_feats: Dict[str, torch.Tensor]
    for node_id, node_feature_values in enumerate(node_features):
        instance_dict_feats = {data.node_id_column_name: torch.LongTensor([node_id])}
        instance_dict_feats.update(
            {
                feat_name: feat_value
                for feat_name, feat_value in zip(feature_names, node_feature_values)
            }
        )
        if node_labels is not None:
            instance_dict_feats.update(
                {data.node_label_column_name: node_labels[node_id]}
            )
        tfrecord_bytes = id2tfe_encoder.process(element=instance_dict_feats)
        tfrecords.append(tfrecord_bytes)

    # Write features to GCS.
    features_path = get_example_task_node_features_gcs_dir(
        task_name=data.name, version=version, node_type=node_type
    )
    with tf.io.TFRecordWriter(
        GcsUri.join(features_path, "data.tfrecord").uri
    ) as writer:
        for tfrecord in tfrecords:
            writer.write(tfrecord)
    logger.info(
        f"Wrote preprocessed node TFRecords for type {node_type} to prefix {features_path.uri}"
    )

    # Write schema to GCS.
    node_schema_uri = get_example_task_node_features_schema_gcs_path(
        task_name=data.name, version=version, node_type=node_type
    )
    node_schema = schema_utils.schema_from_feature_spec(feature_spec=feature_spec_dict)
    file_loader = FileLoader()
    temp_file_handle = tempfile.NamedTemporaryFile()
    with open(temp_file_handle.name, "w") as f:
        f.write(repr(node_schema))
    file_loader.load_file(
        file_uri_src=LocalUri(temp_file_handle.name),
        file_uri_dst=node_schema_uri,
    )
    logger.info(
        f"Wrote preprocessed node TFRecords schema for type {node_type} to prefix {node_schema_uri.uri}"
    )
    return _NodePreprocessMetadata(
        features_uri=features_path,
        schema_uri=node_schema_uri,
        feature_cols=feature_names,
        id_col=data.node_id_column_name,
        label_col=data.node_label_column_name if node_labels is not None else None,
    )


def _generate_preprocessed_edge_tfrecord_data(
    data: MockedDatasetInfo,
    version: str,
    edge_type: EdgeType,
    edge_index: torch.Tensor,
    num_edge_features: int,
    edge_features: Optional[torch.Tensor],
    edge_usage_type: EdgeUsageType,
) -> _EdgePreprocessMetadata:
    feature_names: List[str] = [
        get_feature_field_name(n=i) for i in range(num_edge_features)
    ]
    feature_spec_dict = {
        data.edge_src_column_name: tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
        data.edge_dst_column_name: tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
    }
    feature_spec_dict.update(
        {
            col: tf.io.FixedLenFeature(shape=[], dtype=tf.float32)
            for col in feature_names
        }
    )
    id2tfe_encoder = _InstanceDictToTFExample(feature_spec=feature_spec_dict)
    tfrecords = []
    for edge_id, (src_id, dst_id) in enumerate(
        zip(edge_index[0, :], edge_index[1, :])
    ):
        instance_dict_feats = {
            data.edge_src_column_name: src_id,
            data.edge_dst_column_name: dst_id,
        }
        if edge_features is not None:
            edge_feature_values = edge_features[edge_id, :]
            instance_dict_feats.update(
                {
                    feat_name: feat_value
                    for feat_name, feat_value in zip(
                        feature_names, edge_feature_values
                    )
                }
            )
        tfrecord_bytes = id2tfe_encoder.process(element=instance_dict_feats)
        tfrecords.append(tfrecord_bytes)

    # Write features to GCS.
    features_path = get_example_task_edge_features_gcs_dir(
        task_name=data.name,
        version=version,
        edge_type=edge_type,
        edge_usage_type=edge_usage_type,
    )
    with tf.io.TFRecordWriter(
        GcsUri.join(features_path, "data.tfrecord").uri
    ) as writer:
        for tfrecord in tfrecords:
            writer.write(tfrecord)
    logger.info(
        f"Wrote preprocessed edge TFRecords for type {edge_type} to prefix {features_path.uri}"
    )

    # Write schema to GCS.
    edge_schema_uri = get_example_task_edge_features_schema_gcs_path(
        task_name=data.name,
        version=version,
        edge_type=edge_type,
        edge_usage_type=edge_usage_type,
    )
    edge_schema = schema_utils.schema_from_feature_spec(feature_spec=feature_spec_dict)
    file_loader = FileLoader()
    temp_file_handle = tempfile.NamedTemporaryFile()
    with open(temp_file_handle.name, "w") as f:
        f.write(repr(edge_schema))
    file_loader.load_file(
        file_uri_src=LocalUri(temp_file_handle.name),
        file_uri_dst=edge_schema_uri,
    )
    logger.info(
        f"Wrote preprocessed edge TFRecords schema for type {edge_type} to {edge_schema_uri.uri}"
    )
    return _EdgePreprocessMetadata(
        features_uri=features_path,
        schema_uri=edge_schema_uri,
        feature_cols=feature_names,
        src_id_col=data.edge_src_column_name,
        dst_id_col=data.edge_dst_column_name,
    )
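# --- Illustrative example (not part of the original module) ---
# A minimal sketch of the TFExample round trip that _InstanceDictToTFExample.process
# performs, written against plain TensorFlow APIs. The helper name and toy feature
# values below are hypothetical and exist only for illustration; the function is
# defined here but never called by this module.
def _example_tfexample_round_trip() -> None:
    # Serialize one "row" by hand, mirroring a node record with id 7 and one float feature.
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "node_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
                "f0": tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])),
            }
        )
    )
    serialized = example.SerializeToString()
    # Parsing back with a matching feature spec (the same shape/dtype contract the
    # functions above build) recovers scalar tensors keyed by column name.
    toy_spec = {
        "node_id": tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
        "f0": tf.io.FixedLenFeature(shape=[], dtype=tf.float32),
    }
    parsed = tf.io.parse_single_example(serialized, features=toy_spec)
    assert int(parsed["node_id"].numpy()) == 7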
def generate_preprocessed_tfrecord_data(
    mocked_dataset_info: MockedDatasetInfo,
    version: str,
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
):
    graph_metadata_pb_wrapper = mocked_dataset_info.graph_metadata_pb_wrapper
    num_features_by_node_type = mocked_dataset_info.num_node_features
    node_features_by_node_type = mocked_dataset_info.node_feats
    node_labels_by_node_type = mocked_dataset_info.node_labels
    node_types = mocked_dataset_info.node_types
    condensed_node_type_to_preprocessed_metadata: Dict[
        CondensedNodeType,
        preprocessed_metadata_pb2.PreprocessedMetadata.NodeMetadataOutput,
    ] = dict()
    for node_type in node_types:
        condensed_node_type = (
            graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[node_type]
        )
        num_features = num_features_by_node_type[node_type]
        node_features = node_features_by_node_type[node_type]
        if not num_features:
            continue
        node_labels = (
            node_labels_by_node_type[node_type]
            if node_labels_by_node_type is not None
            else None
        )
        node_preprocess_metadata = _generate_preprocessed_node_tfrecord_data(
            data=mocked_dataset_info,
            version=version,
            node_type=node_type,
            num_node_features=num_features,
            node_features=node_features,
            node_labels=node_labels,
        )
        condensed_node_type_to_preprocessed_metadata[
            condensed_node_type
        ] = preprocessed_metadata_pb2.PreprocessedMetadata.NodeMetadataOutput(
            node_id_key=node_preprocess_metadata.id_col,
            feature_keys=node_preprocess_metadata.feature_cols,
            label_keys=[node_preprocess_metadata.label_col]
            if node_preprocess_metadata.label_col is not None
            else None,  # type: ignore
            tfrecord_uri_prefix=node_preprocess_metadata.features_uri.uri,
            schema_uri=node_preprocess_metadata.schema_uri.uri,
            feature_dim=num_features,
        )

    num_features_by_edge_type = mocked_dataset_info.num_edge_features
    edge_features_by_edge_type = mocked_dataset_info.edge_feats
    edge_index_by_edge_type = mocked_dataset_info.edge_index
    edge_types = mocked_dataset_info.edge_types
    condensed_edge_type_to_preprocessed_metadata: Dict[
        CondensedEdgeType,
        preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataOutput,
    ] = dict()
    for edge_type in edge_types:
        condensed_edge_type = (
            graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[edge_type]
        )
        num_features = num_features_by_edge_type[edge_type]
        edge_features = (
            edge_features_by_edge_type[edge_type]
            if edge_features_by_edge_type is not None
            else None
        )
        edge_index = edge_index_by_edge_type[edge_type]
        edge_preprocess_metadata = _generate_preprocessed_edge_tfrecord_data(
            data=mocked_dataset_info,
            version=version,
            edge_type=edge_type,
            edge_index=edge_index,
            num_edge_features=num_features,
            edge_features=edge_features,
            edge_usage_type=EdgeUsageType.MAIN,
        )
        main_edge_metadata_info_pb = (
            preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataInfo(
                feature_keys=edge_preprocess_metadata.feature_cols,
                tfrecord_uri_prefix=edge_preprocess_metadata.features_uri.uri,
                schema_uri=edge_preprocess_metadata.schema_uri.uri,
                feature_dim=num_features,
            )
        )
        if (
            mocked_dataset_info.user_defined_edge_index
            and edge_type in mocked_dataset_info.user_defined_edge_index
        ):
            edge_preprocess_metadata_pb_dict = {}
            for (
                user_def_label,
                user_def_edge_index,
            ) in mocked_dataset_info.user_defined_edge_index[edge_type].items():
                num_edge_feats = mocked_dataset_info.num_user_def_edge_features[
                    edge_type
                ][user_def_label]
                user_defined_edge_feats = (
                    mocked_dataset_info.user_defined_edge_feats[edge_type][
                        user_def_label
                    ]
                    if mocked_dataset_info.user_defined_edge_feats
                    and edge_type in mocked_dataset_info.user_defined_edge_feats
                    else None
                )
                user_def_edge_preprocess_metadata = (
                    _generate_preprocessed_edge_tfrecord_data(
                        data=mocked_dataset_info,
                        version=version,
                        edge_type=edge_type,
                        edge_index=user_def_edge_index,
                        num_edge_features=num_edge_feats,
                        edge_features=user_defined_edge_feats,
                        edge_usage_type=user_def_label,
                    )
                )
                user_def_edge_metadata_info_pb = preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataInfo(
                    feature_keys=user_def_edge_preprocess_metadata.feature_cols,
                    tfrecord_uri_prefix=user_def_edge_preprocess_metadata.features_uri.uri,
                    schema_uri=user_def_edge_preprocess_metadata.schema_uri.uri,
                    feature_dim=num_edge_feats,
                )
                edge_preprocess_metadata_pb_dict[
                    user_def_label
                ] = user_def_edge_metadata_info_pb
            condensed_edge_type_to_preprocessed_metadata[
                condensed_edge_type
            ] = preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataOutput(
                src_node_id_key=edge_preprocess_metadata.src_id_col,
                dst_node_id_key=edge_preprocess_metadata.dst_id_col,
                main_edge_info=main_edge_metadata_info_pb,
                positive_edge_info=edge_preprocess_metadata_pb_dict.get(
                    EdgeUsageType.POSITIVE, None
                ),
                negative_edge_info=edge_preprocess_metadata_pb_dict.get(
                    EdgeUsageType.NEGATIVE, None
                ),
            )
        else:
            condensed_edge_type_to_preprocessed_metadata[
                condensed_edge_type
            ] = preprocessed_metadata_pb2.PreprocessedMetadata.EdgeMetadataOutput(
                src_node_id_key=edge_preprocess_metadata.src_id_col,
                dst_node_id_key=edge_preprocess_metadata.dst_id_col,
                main_edge_info=main_edge_metadata_info_pb,
            )

    # Assemble Preprocessed Metadata pb and write out.
    preprocessed_metadata_pb = preprocessed_metadata_pb2.PreprocessedMetadata()
    for (
        condensed_node_type,
        node_metadata_output,
    ) in condensed_node_type_to_preprocessed_metadata.items():
        preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[
            condensed_node_type
        ].CopyFrom(node_metadata_output)
    for (
        condensed_edge_type,
        edge_metadata_output,
    ) in condensed_edge_type_to_preprocessed_metadata.items():
        preprocessed_metadata_pb.condensed_edge_type_to_preprocessed_metadata[
            condensed_edge_type
        ].CopyFrom(edge_metadata_output)
    preprocessed_metadata_uri = UriFactory.create_uri(
        gbml_config_pb.shared_config.preprocessed_metadata_uri
    )
    proto_utils = ProtoUtils()
    proto_utils.write_proto_to_yaml(
        proto=preprocessed_metadata_pb, uri=preprocessed_metadata_uri
    )
    logger.info(
        f"Wrote preprocessed metadata proto to {preprocessed_metadata_uri.uri}."
    )
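# --- Illustrative example (not part of the original module) ---
# A minimal sketch of reading the mocked TFRecords back, assuming the reader supplies
# the same feature spec that was used at write time. The helper name is hypothetical;
# tf.data can consume gs:// URIs directly when the TensorFlow GCS filesystem is available.
def _example_read_tfrecords(
    tfrecord_uri: str, feature_spec: FeatureSpecDict
) -> tf.data.Dataset:
    # Each record was serialized by _InstanceDictToTFExample.process, so parsing with
    # the matching feature spec yields one dict of scalar tensors per row.
    dataset = tf.data.TFRecordDataset([tfrecord_uri])
    return dataset.map(
        lambda serialized: tf.io.parse_single_example(serialized, features=feature_spec)
    )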