"""Utilities for translating GiGL metadata protobufs into SerializedGraphMetadata (module gigl.distributed.utils.serialized_graph_metadata_translator)."""
from typing import Tuple, Union
from gigl.common import UriFactory
from gigl.common.data.dataloaders import SerializedTFRecordInfo
from gigl.common.data.load_torch_tensors import SerializedGraphMetadata
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.types.pb_wrappers.preprocessed_metadata import (
    PreprocessedMetadataPbWrapper,
)
from gigl.src.data_preprocessor.lib.types import FeatureSpecDict
from gigl.types.graph import to_homogeneous
from snapchat.research.gbml.preprocessed_metadata_pb2 import PreprocessedMetadata
def _build_serialized_tfrecord_entity_info(
    preprocessed_metadata: Union[
        PreprocessedMetadata.NodeMetadataOutput, PreprocessedMetadata.EdgeMetadataInfo
    ],
    feature_spec_dict: FeatureSpecDict,
    entity_key: Union[str, Tuple[str, str]],
    tfrecord_uri_pattern: str,
) -> SerializedTFRecordInfo:
    """
    Build the SerializedTFRecordInfo describing one node or edge entity of a single node/edge type.

    Args:
        preprocessed_metadata (Union[
            PreprocessedMetadata.NodeMetadataOutput, PreprocessedMetadata.EdgeMetadataInfo
        ]): Preprocessed metadata pb (NodeMetadataOutput or EdgeMetadataInfo) supplying the
            TFRecord URI prefix, feature keys, feature dimension and label keys.
        feature_spec_dict (FeatureSpecDict): Feature spec stored on the resulting SerializedTFRecordInfo.
        entity_key (Union[str, Tuple[str, str]]): Entity key stored on the resulting SerializedTFRecordInfo;
            a str for a node entity, or a (src, dst) Tuple[str, str] for an edge entity.
        tfrecord_uri_pattern (str): Regex pattern for loading serialized tf records.
    Returns:
        SerializedTFRecordInfo: Stored metadata for the given entity.
    """
    # The proto stores the prefix as a plain string; wrap it in a Uri object.
    uri_prefix = UriFactory.create_uri(preprocessed_metadata.tfrecord_uri_prefix)
    # Copy the repeated proto fields into plain Python lists.
    feature_key_list = list(preprocessed_metadata.feature_keys)
    dimension = preprocessed_metadata.feature_dim
    label_key_list = list(preprocessed_metadata.label_keys)

    entity_info = SerializedTFRecordInfo(
        tfrecord_uri_prefix=uri_prefix,
        feature_keys=feature_key_list,
        feature_spec=feature_spec_dict,
        feature_dim=dimension,
        entity_key=entity_key,
        label_keys=label_key_list,
        tfrecord_uri_pattern=tfrecord_uri_pattern,
    )
    return entity_info
[docs]
def convert_pb_to_serialized_graph_metadata(
    preprocessed_metadata_pb_wrapper: PreprocessedMetadataPbWrapper,
    graph_metadata_pb_wrapper: GraphMetadataPbWrapper,
    tfrecord_uri_pattern: str = r".*-of-.*\.tfrecord(\.gz)?$",
) -> SerializedGraphMetadata:
    """
    Populates a SerializedGraphMetadata from PreprocessedMetadataPbWrapper and GraphMetadataPbWrapper,
    containing information for loading tensors for all entities and node/edge types.

    For every node type, a node entity entry is always produced. For every edge type, up to three
    entries are produced, one for each of main_edge_info, positive_edge_info and negative_edge_info
    that is set on the edge's preprocessed metadata.

    Args:
        preprocessed_metadata_pb_wrapper (PreprocessedMetadataPbWrapper): Preprocessed Metadata Pb Wrapper
            to translate into SerializedGraphMetadata.
        graph_metadata_pb_wrapper (GraphMetadataPbWrapper): Graph Metadata Pb Wrapper that provides the
            node/edge types, their condensed ids, and whether the graph is heterogeneous.
        tfrecord_uri_pattern (str): Regex pattern for loading serialized tf records.
    Returns:
        SerializedGraphMetadata: Dataset Metadata for all entity and node/edge types. For homogeneous
            graphs each mapping is passed through ``to_homogeneous``. The positive/negative label fields
            are None when no edge type carries that label info.
    """
    node_entity_info: dict[NodeType, SerializedTFRecordInfo] = {}
    edge_entity_info: dict[EdgeType, SerializedTFRecordInfo] = {}
    positive_label_entity_info: dict[EdgeType, SerializedTFRecordInfo] = {}
    negative_label_entity_info: dict[EdgeType, SerializedTFRecordInfo] = {}

    preprocessed_metadata_pb = preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb

    for node_type in graph_metadata_pb_wrapper.node_types:
        # Preprocessed metadata and feature schemas are keyed by condensed node type.
        condensed_node_type = (
            graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[node_type]
        )
        node_metadata = (
            preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[
                condensed_node_type
            ]
        )
        node_feature_spec_dict = (
            preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_schema_map[
                condensed_node_type
            ].feature_spec
        )
        node_entity_info[node_type] = _build_serialized_tfrecord_entity_info(
            preprocessed_metadata=node_metadata,
            feature_spec_dict=node_feature_spec_dict,
            entity_key=node_metadata.node_id_key,
            tfrecord_uri_pattern=tfrecord_uri_pattern,
        )

    for edge_type in graph_metadata_pb_wrapper.edge_types:
        condensed_edge_type = (
            graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[edge_type]
        )
        edge_metadata = (
            preprocessed_metadata_pb.condensed_edge_type_to_preprocessed_metadata[
                condensed_edge_type
            ]
        )
        # Edge entities are keyed by their (src, dst) node id column names.
        edge_key = (
            edge_metadata.src_node_id_key,
            edge_metadata.dst_node_id_key,
        )
        # Each edge info is optional on the proto. The matching feature schema map
        # is only consulted when that info is actually present.
        if edge_metadata.HasField("main_edge_info"):
            edge_feature_spec_dict = preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_schema_map[
                condensed_edge_type
            ].feature_spec
            edge_entity_info[edge_type] = _build_serialized_tfrecord_entity_info(
                preprocessed_metadata=edge_metadata.main_edge_info,
                feature_spec_dict=edge_feature_spec_dict,
                entity_key=edge_key,
                tfrecord_uri_pattern=tfrecord_uri_pattern,
            )
        if edge_metadata.HasField("positive_edge_info"):
            pos_edge_feature_spec_dict = preprocessed_metadata_pb_wrapper.condensed_edge_type_to_pos_edge_feature_schema_map[
                condensed_edge_type
            ].feature_spec
            positive_label_entity_info[
                edge_type
            ] = _build_serialized_tfrecord_entity_info(
                preprocessed_metadata=edge_metadata.positive_edge_info,
                feature_spec_dict=pos_edge_feature_spec_dict,
                entity_key=edge_key,
                tfrecord_uri_pattern=tfrecord_uri_pattern,
            )
        if edge_metadata.HasField("negative_edge_info"):
            hard_neg_edge_feature_spec_dict = preprocessed_metadata_pb_wrapper.condensed_edge_type_to_hard_neg_edge_feature_schema_map[
                condensed_edge_type
            ].feature_spec
            negative_label_entity_info[
                edge_type
            ] = _build_serialized_tfrecord_entity_info(
                preprocessed_metadata=edge_metadata.negative_edge_info,
                feature_spec_dict=hard_neg_edge_feature_spec_dict,
                entity_key=edge_key,
                tfrecord_uri_pattern=tfrecord_uri_pattern,
            )

    # If our input is homogeneous, we remove the node/edge type component of the metadata fields.
    strip_types = not graph_metadata_pb_wrapper.is_heterogeneous

    def _finalize(entity_info):
        # Apply the homogeneous conversion only when the graph is homogeneous.
        return to_homogeneous(entity_info) if strip_types else entity_info

    return SerializedGraphMetadata(
        node_entity_info=_finalize(node_entity_info),
        edge_entity_info=_finalize(edge_entity_info),
        # Label entity info is optional: report None rather than an empty mapping.
        positive_label_entity_info=(
            _finalize(positive_label_entity_info)
            if positive_label_entity_info
            else None
        ),
        negative_label_entity_info=(
            _finalize(negative_label_entity_info)
            if negative_label_entity_info
            else None
        ),
    )
