# Source code for gigl.distributed.utils.serialized_graph_metadata_translator

from typing import Dict, Optional, Tuple, Union

from gigl.common import UriFactory
from gigl.common.data.dataloaders import SerializedTFRecordInfo
from gigl.common.data.load_torch_tensors import SerializedGraphMetadata
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.types.pb_wrappers.preprocessed_metadata import (
    PreprocessedMetadataPbWrapper,
)
from gigl.src.data_preprocessor.lib.types import FeatureSpecDict
from gigl.types.graph import to_homogeneous
from snapchat.research.gbml.preprocessed_metadata_pb2 import PreprocessedMetadata


def _build_serialized_tfrecord_entity_info(
    preprocessed_metadata: Union[
        PreprocessedMetadata.NodeMetadataOutput, PreprocessedMetadata.EdgeMetadataInfo
    ],
    feature_spec_dict: FeatureSpecDict,
    entity_key: Union[str, Tuple[str, str]],
    tfrecord_uri_pattern: str,
) -> SerializedTFRecordInfo:
    """
    Assembles the SerializedTFRecordInfo describing one node type or one edge type.

    The URI prefix, feature keys and feature dimension are read from the metadata proto;
    the feature spec, entity key and URI pattern are forwarded unchanged from the arguments.

    Args:
        preprocessed_metadata (Union[
            PreprocessedMetadata.NodeMetadataOutput, PreprocessedMetadata.EdgeMetadataInfo
        ]): Preprocessed metadata pb of the node or edge entity being described
        feature_spec_dict (FeatureSpecDict): Feature spec stored on the resulting SerializedTFRecordInfo
        entity_key (Union[str, Tuple[str, str]]): Key stored on the resulting SerializedTFRecordInfo;
            a single str for a node entity, a Tuple[str, str] for an edge entity
        tfrecord_uri_pattern (str): Regex pattern for loading serialized tf records
    Returns:
        SerializedTFRecordInfo: Metadata record for the given entity
    """
    # The proto carries the prefix as a plain value; wrap it through UriFactory.
    uri_prefix = UriFactory.create_uri(preprocessed_metadata.tfrecord_uri_prefix)
    # Materialize the feature keys as a regular Python list.
    feature_key_list = list(preprocessed_metadata.feature_keys)
    feature_dimension = preprocessed_metadata.feature_dim

    return SerializedTFRecordInfo(
        tfrecord_uri_prefix=uri_prefix,
        feature_keys=feature_key_list,
        feature_spec=feature_spec_dict,
        feature_dim=feature_dimension,
        entity_key=entity_key,
        tfrecord_uri_pattern=tfrecord_uri_pattern,
    )


def convert_pb_to_serialized_graph_metadata(
    preprocessed_metadata_pb_wrapper: PreprocessedMetadataPbWrapper,
    graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
    tfrecord_uri_pattern: str = r".*-of-.*\.tfrecord(\.gz)?$",
) -> SerializedGraphMetadata:
    """
    Populates a SerializedGraphMetadata field from PreprocessedMetadataPbWrapper and GraphMetadataPbWrapper,
    containing information for loading tensors for all entities and node/edge types.

    For every node type and edge type listed by the graph metadata wrapper, the matching condensed
    type is looked up and a SerializedTFRecordInfo is built from the preprocessed metadata. For each
    edge type, positive / hard-negative label info is additionally built when the preprocessed metadata
    wrapper reports positive / hard-negative edge features for it, and is recorded as None otherwise.

    Args:
        preprocessed_metadata_pb_wrapper (PreprocessedMetadataPbWrapper): Preprocessed Metadata Pb Wrapper to translate into SerializedGraphMetadata
        graph_metadata_pb_wrapper (GraphMetadataPbWrapper): Graph Metadata Pb Wrapper to translate into Dataset Metadata
        tfrecord_uri_pattern (str): Regex pattern for loading serialized tf records. The default accepts
            names containing "-of-" that end in ".tfrecord", optionally followed by ".gz".
    Returns:
        SerializedGraphMetadata: Dataset Metadata for all entity and node/edge types. For heterogeneous
            input, the positive / negative label fields are None when no edge type has the respective labels.
    """
    node_entity_info: Dict[NodeType, SerializedTFRecordInfo] = {}
    edge_entity_info: Dict[EdgeType, SerializedTFRecordInfo] = {}
    positive_label_entity_info: Dict[EdgeType, Optional[SerializedTFRecordInfo]] = {}
    negative_label_entity_info: Dict[EdgeType, Optional[SerializedTFRecordInfo]] = {}

    preprocessed_metadata_pb = preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb

    for node_type in graph_metadata_pb_wrapper.node_types:
        # Preprocessed metadata is keyed by condensed type, so translate first.
        condensed_node_type = (
            graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[node_type]
        )
        node_metadata = (
            preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[
                condensed_node_type
            ]
        )
        node_feature_spec_dict = (
            preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_schema_map[
                condensed_node_type
            ].feature_spec
        )

        node_entity_info[node_type] = _build_serialized_tfrecord_entity_info(
            preprocessed_metadata=node_metadata,
            feature_spec_dict=node_feature_spec_dict,
            entity_key=node_metadata.node_id_key,
            tfrecord_uri_pattern=tfrecord_uri_pattern,
        )

    for edge_type in graph_metadata_pb_wrapper.edge_types:
        condensed_edge_type = (
            graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[edge_type]
        )
        edge_metadata = (
            preprocessed_metadata_pb.condensed_edge_type_to_preprocessed_metadata[
                condensed_edge_type
            ]
        )
        edge_feature_spec_dict = (
            preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_schema_map[
                condensed_edge_type
            ].feature_spec
        )

        # The same (src, dst) key pair is shared by the main, positive and negative edge info.
        edge_key = (
            edge_metadata.src_node_id_key,
            edge_metadata.dst_node_id_key,
        )

        edge_entity_info[edge_type] = _build_serialized_tfrecord_entity_info(
            preprocessed_metadata=edge_metadata.main_edge_info,
            feature_spec_dict=edge_feature_spec_dict,
            entity_key=edge_key,
            tfrecord_uri_pattern=tfrecord_uri_pattern,
        )

        # Every edge type gets an entry in the label dicts; None marks "no labels for this type".
        if preprocessed_metadata_pb_wrapper.has_pos_edge_features(
            condensed_edge_type=condensed_edge_type
        ):
            pos_edge_feature_spec_dict = preprocessed_metadata_pb_wrapper.condensed_edge_type_to_pos_edge_feature_schema_map[
                condensed_edge_type
            ].feature_spec
            positive_label_entity_info[
                edge_type
            ] = _build_serialized_tfrecord_entity_info(
                preprocessed_metadata=edge_metadata.positive_edge_info,
                feature_spec_dict=pos_edge_feature_spec_dict,
                entity_key=edge_key,
                tfrecord_uri_pattern=tfrecord_uri_pattern,
            )
        else:
            positive_label_entity_info[edge_type] = None

        if preprocessed_metadata_pb_wrapper.has_hard_neg_edge_features(
            condensed_edge_type=condensed_edge_type
        ):
            hard_neg_edge_feature_spec_dict = preprocessed_metadata_pb_wrapper.condensed_edge_type_to_hard_neg_edge_feature_schema_map[
                condensed_edge_type
            ].feature_spec
            negative_label_entity_info[
                edge_type
            ] = _build_serialized_tfrecord_entity_info(
                preprocessed_metadata=edge_metadata.negative_edge_info,
                feature_spec_dict=hard_neg_edge_feature_spec_dict,
                entity_key=edge_key,
                tfrecord_uri_pattern=tfrecord_uri_pattern,
            )
        else:
            negative_label_entity_info[edge_type] = None

    if not graph_metadata_pb_wrapper.is_heterogeneous:
        # If our input is homogeneous, we remove the node/edge type component of the metadata fields.
        return SerializedGraphMetadata(
            node_entity_info=to_homogeneous(node_entity_info),
            edge_entity_info=to_homogeneous(edge_entity_info),
            positive_label_entity_info=to_homogeneous(positive_label_entity_info),
            negative_label_entity_info=to_homogeneous(negative_label_entity_info),
        )

    # Heterogeneous input: keep the per-type dicts, but collapse a label dict to None
    # when no edge type carries that kind of label (this includes the empty-dict case).
    has_positive_labels = any(
        entity_info is not None for entity_info in positive_label_entity_info.values()
    )
    has_negative_labels = any(
        entity_info is not None for entity_info in negative_label_entity_info.values()
    )
    return SerializedGraphMetadata(
        node_entity_info=node_entity_info,
        edge_entity_info=edge_entity_info,
        positive_label_entity_info=(
            positive_label_entity_info if has_positive_labels else None
        ),
        negative_label_entity_info=(
            negative_label_entity_info if has_negative_labels else None
        ),
    )