Source code for gigl.src.mocking.lib.mock_input_for_split_generator

from typing import List

import torch_geometric.transforms as T
from torch_geometric.data import HeteroData

from gigl.common import UriFactory
from gigl.common.logger import Logger
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.src.mocking.lib import pyg_to_training_samples, tfrecord_io
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
from snapchat.research.gbml import gbml_config_pb2, training_samples_schema_pb2

# Module-level logger instance (gigl's Logger wrapper).
logger = Logger()
def build_and_write_supervised_node_classification_subgraph_samples_from_mocked_dataset_info(
    mocked_dataset_info: MockedDatasetInfo,
    root_node_type: NodeType,
    gbml_config_pb: gbml_config_pb2.GbmlConfig,
) -> HeteroData:
    """Build supervised node classification samples from a mocked dataset and write them out.

    The mocked dataset is first converted to a PyG ``HeteroData`` object, from
    which ``SupervisedNodeClassificationSample`` protos are built. The samples
    are then written as TFRecord shards under the labeled-sample URI prefix
    given by the GbmlConfig.

    Args:
        mocked_dataset_info: Mocked dataset to convert. Its
            ``graph_metadata_pb_wrapper`` is forwarded to the sample builder.
        root_node_type: Node type passed to the sample builder as the root
            node type of the generated samples.
        gbml_config_pb: Config whose
            ``shared_config.flattened_graph_metadata.supervised_node_classification_output.labeled_tfrecord_uri_prefix``
            determines where the samples are written.

    Returns:
        The ``HeteroData`` object built from ``mocked_dataset_info``.
    """
    hetero_data = pyg_to_training_samples.build_pyg_heterodata_from_mocked_dataset_info(
        mocked_dataset_info=mocked_dataset_info
    )
    samples: List[
        training_samples_schema_pb2.SupervisedNodeClassificationSample
    ] = pyg_to_training_samples.build_supervised_node_classification_samples_from_pyg_heterodata(
        hetero_data=hetero_data,
        root_node_type=root_node_type,
        graph_metadata_pb_wrapper=mocked_dataset_info.graph_metadata_pb_wrapper,
    )

    # Write out to GbmlConfig-specified paths
    output_paths = (
        gbml_config_pb.shared_config.flattened_graph_metadata.supervised_node_classification_output
    )
    labeled_sample_tfrecord_uri = UriFactory.create_uri(
        output_paths.labeled_tfrecord_uri_prefix
    )
    tfrecord_io.write_pb_tfrecord_shards_to_uri(
        pb_samples=samples,
        uri_prefix=labeled_sample_tfrecord_uri,
        sample_type_for_logging="labeled SNC",
    )

    return hetero_data