Source code for gigl.src.mocking.lib.mock_input_for_data_preprocessor

import json
import tempfile
from typing import Optional

import torch
from google.cloud import bigquery

from gigl.common import UriFactory
from gigl.common.logger import Logger
from gigl.src.common.types.graph_data import EdgeType, EdgeUsageType, NodeType
from gigl.src.common.utils.bq import BqUtils
from gigl.src.mocking.lib.constants import (
    get_example_task_edges_bq_table_path,
    get_example_task_nodes_bq_table_path,
)
from gigl.src.mocking.lib.feature_handling import get_feature_field_name
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo

logger = Logger()
def _generate_bigquery_assets_for_single_node_type(
    mocked_dataset_info: MockedDatasetInfo,
    version: str,
    node_type: NodeType,
    num_nodes: int,
    num_node_features: int,
    node_feats: torch.Tensor,
    node_labels: Optional[torch.Tensor],
):
    """Load one node type's ids, features and optional labels into BigQuery.

    Rows are serialized as newline-delimited JSON into a local temporary file,
    which is then loaded (with ``WRITE_TRUNCATE``) into the table returned by
    ``get_example_task_nodes_bq_table_path``.

    Args:
        mocked_dataset_info: Supplies the dataset name and the node id / node
            label column names.
        version: Forwarded to the table-path helper.
        node_type: Forwarded to the table-path helper.
        num_nodes: Number of rows written; node ids are ``0 .. num_nodes - 1``.
        num_node_features: Number of feature columns written per node.
        node_feats: Indexed as ``node_feats[i][j]`` for node ``i``, feature ``j``.
        node_labels: Optional labels, reshaped to one row per node. When
            ``None``, no label column is written and none is added to the schema.
    """
    node_feature_column_names = [
        get_feature_field_name(n=n) for n in range(num_node_features)
    ]
    node_ids = torch.arange(num_nodes).reshape(-1, 1)
    node_labels = node_labels.reshape(-1, 1) if node_labels is not None else None

    # Write through a single handle and close it before the upload.
    # delete=False keeps the file on disk so it can be re-read by path.
    # NOTE(review): as before, the temp file is not removed afterwards; it
    # could be deleted here once load_file_to_bq is confirmed to be synchronous.
    with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
        local_path = f.name
        for node_idx in range(num_nodes):
            node_feat = node_feats[node_idx]
            row = {mocked_dataset_info.node_id_column_name: node_ids[node_idx].item()}
            for feat_idx, column_name in enumerate(node_feature_column_names):
                row[column_name] = node_feat[feat_idx].item()
            # Test for presence, not truthiness: a one-element tensor holding 0
            # is falsy, which previously dropped the label of every class-0 node.
            if node_labels is not None:
                row[mocked_dataset_info.node_label_column_name] = node_labels[
                    node_idx
                ].item()
            f.write(f"{json.dumps(row)}\n")

    node_features_schema = [
        bigquery.SchemaField(mocked_dataset_info.node_id_column_name, "INTEGER"),
    ] + [
        bigquery.SchemaField(column_name, "FLOAT")
        for column_name in node_feature_column_names
    ]
    if node_labels is not None:
        node_features_schema += [
            bigquery.SchemaField(mocked_dataset_info.node_label_column_name, "INTEGER")
        ]

    nodes_bq_table = get_example_task_nodes_bq_table_path(
        task_name=mocked_dataset_info.name, version=version, node_type=node_type
    )
    bq_utils = BqUtils()
    bq_utils.load_file_to_bq(
        source_path=UriFactory.create_uri(local_path),
        bq_path=nodes_bq_table,
        job_config=bigquery.LoadJobConfig(
            source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
            schema=node_features_schema,
        ),
        retry=True,
    )
    logger.info(
        f"{mocked_dataset_info.name} node data loaded to BQ table {nodes_bq_table}"
    )


def _generate_bigquery_assets_for_single_edge_type(
    mocked_dataset_info: MockedDatasetInfo,
    version: str,
    edge_type: EdgeType,
    edge_index: torch.Tensor,
    num_edge_features: int,
    edge_feats: Optional[torch.Tensor],
    edge_usage_type: EdgeUsageType,
):
    """Load one edge type's (src, dst) pairs and optional features into BigQuery.

    Rows are serialized as newline-delimited JSON into a local temporary file,
    which is then loaded (with ``WRITE_TRUNCATE``) into the table returned by
    ``get_example_task_edges_bq_table_path``.

    Args:
        mocked_dataset_info: Supplies the dataset name and the edge src / dst
            column names.
        version: Forwarded to the table-path helper.
        edge_type: Forwarded to the table-path helper.
        edge_index: Iterated via ``edge_index.T``; each item is unpacked as a
            ``(src, dst)`` pair and written as one row.
        num_edge_features: Number of feature column names generated.
        edge_feats: Optional features, indexed as ``edge_feats[i]`` for the
            ``i``-th edge. When ``None``, no feature columns are written and
            none are added to the schema.
        edge_usage_type: Forwarded to the table-path helper.
    """
    edge_feature_column_names = [
        get_feature_field_name(n=n) for n in range(num_edge_features)
    ]

    # Write through a single handle and close it before the upload.
    # delete=False keeps the file on disk so it can be re-read by path.
    # NOTE(review): as before, the temp file is not removed afterwards; it
    # could be deleted here once load_file_to_bq is confirmed to be synchronous.
    with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
        local_path = f.name
        for edge_idx, (src, dst) in enumerate(edge_index.T):
            row = {
                mocked_dataset_info.edge_src_column_name: src.item(),
                mocked_dataset_info.edge_dst_column_name: dst.item(),
            }
            if edge_feats is not None:
                for column_name, edge_feature in zip(
                    edge_feature_column_names, edge_feats[edge_idx]
                ):
                    row[column_name] = edge_feature.item()
            f.write(f"{json.dumps(row)}\n")

    edge_features_schema = [
        bigquery.SchemaField(mocked_dataset_info.edge_src_column_name, "INTEGER"),
        bigquery.SchemaField(mocked_dataset_info.edge_dst_column_name, "INTEGER"),
    ]
    if edge_feats is not None:
        edge_features_schema += [
            bigquery.SchemaField(column_name, "FLOAT")
            for column_name in edge_feature_column_names
        ]

    edges_bq_table = get_example_task_edges_bq_table_path(
        task_name=mocked_dataset_info.name,
        version=version,
        edge_type=edge_type,
        edge_usage_type=edge_usage_type,
    )
    bq_utils = BqUtils()
    bq_utils.load_file_to_bq(
        source_path=UriFactory.create_uri(local_path),
        bq_path=edges_bq_table,
        job_config=bigquery.LoadJobConfig(
            source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
            schema=edge_features_schema,
        ),
        retry=True,
    )
    logger.info(
        f"{mocked_dataset_info.name} edge data loaded to BQ table {edges_bq_table}"
    )
def generate_bigquery_assets(mocked_dataset_info: MockedDatasetInfo, version: str):
    """
    Materialize the mocked dataset as BigQuery tables.

    One table is produced per node type, one per edge type for the main
    edges, and one more for every (edge type, edge usage type) entry present
    in the dataset's user-defined edge index.
    """
    # --- Node tables -------------------------------------------------------
    all_node_types = mocked_dataset_info.node_types
    node_count_by_type = mocked_dataset_info.num_nodes
    node_feature_dim_by_type = mocked_dataset_info.num_node_features

    for node_type in all_node_types:
        node_count = node_count_by_type[node_type]
        node_feature_dim = node_feature_dim_by_type[node_type]
        feats_for_type = mocked_dataset_info.node_feats[node_type]

        # Labels are optional at the dataset level.
        if mocked_dataset_info.node_labels is None:
            labels_for_type = None
        else:
            labels_for_type = mocked_dataset_info.node_labels[node_type]

        _generate_bigquery_assets_for_single_node_type(
            mocked_dataset_info=mocked_dataset_info,
            version=version,
            node_type=node_type,
            num_nodes=node_count,
            num_node_features=node_feature_dim,
            node_feats=feats_for_type,
            node_labels=labels_for_type,
        )

    # --- Edge tables -------------------------------------------------------
    all_edge_types = mocked_dataset_info.edge_types
    main_edge_index_by_type = mocked_dataset_info.edge_index
    edge_feature_dim_by_type = mocked_dataset_info.num_edge_features

    for edge_type in all_edge_types:
        main_edge_index = main_edge_index_by_type[edge_type]
        main_feature_dim = edge_feature_dim_by_type[edge_type]

        # Main edge features are optional at the dataset level.
        if mocked_dataset_info.edge_feats:
            main_edge_feats = mocked_dataset_info.edge_feats[edge_type]
        else:
            main_edge_feats = None

        _generate_bigquery_assets_for_single_edge_type(
            mocked_dataset_info=mocked_dataset_info,
            version=version,
            edge_type=edge_type,
            edge_index=main_edge_index,
            num_edge_features=main_feature_dim,
            edge_feats=main_edge_feats,
            edge_usage_type=EdgeUsageType.MAIN,
        )

        # Nothing more to do unless this edge type has user-defined edges.
        if (
            not mocked_dataset_info.user_defined_edge_index
            or edge_type not in mocked_dataset_info.user_defined_edge_index
        ):
            continue

        index_by_usage = mocked_dataset_info.user_defined_edge_index[edge_type]
        for usage_type, usage_edge_index in index_by_usage.items():
            # Features for user-defined edges are optional per edge type.
            usage_edge_feats = None
            if (
                mocked_dataset_info.user_defined_edge_feats
                and edge_type in mocked_dataset_info.user_defined_edge_feats
            ):
                usage_edge_feats = mocked_dataset_info.user_defined_edge_feats[
                    edge_type
                ][usage_type]

            usage_feature_dim = mocked_dataset_info.num_user_def_edge_features[
                edge_type
            ][usage_type]

            _generate_bigquery_assets_for_single_edge_type(
                mocked_dataset_info=mocked_dataset_info,
                version=version,
                edge_type=edge_type,
                edge_index=usage_edge_index,
                num_edge_features=usage_feature_dim,
                edge_feats=usage_edge_feats,
                edge_usage_type=usage_type,
            )