(Optional) Fetch MAG240M Data into your own project#
The latest version of this notebook can be found on GitHub.
Install reqs#
!pip install ogb
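The notebook also imports fastavro, the Google Cloud BigQuery/Storage clients, and gigl; if these are not already present in your environment, they can be installed in the same way (the package names below assume a standard pip setup):
!pip install fastavro google-cloud-bigquery google-cloud-storage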
import multiprocessing
import os
from io import BytesIO
from typing import Optional
import numpy as np
from common import MAG240_DATASET_PATH
from fastavro import writer
from google.cloud import bigquery, storage
from ogb.lsc import MAG240MDataset
from gigl.common import GcsUri
from gigl.common.utils.gcs import GcsUtils
from gigl.src.common.utils.bq import BqUtils
BASE_GCS_BUCKET_PATH = (
    "REPLACE THIS WITH YOUR GCS PATH"  # WARN: replace with your own GCS path, e.g. gs://<bucket>/<prefix>
)
BASE_BQ_PATH = "REPLACE THIS WITH YOUR BQ PATH"  # WARN: replace with your own BQ path, e.g. <project>.<dataset>.<table_prefix>
def get_gcs_path_for_asset(asset_name):
return f"{BASE_GCS_BUCKET_PATH}/{asset_name}"
def get_bq_path_for_asset(asset_name):
return f"{BASE_BQ_PATH}_{asset_name}"
def fetch_dataset() -> MAG240MDataset:
    # !pip install -U ogb  # We pull the dataset from OGB: https://ogb.stanford.edu/docs/lsc/
    # Example of how this dataset can be used: https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/gnn.py
    # WARNING: This code block can take hours or even days to run if the dataset is not already available locally.
    # The upstream servers are the bottleneck: even on cloud instances with 100GB/s+ network,
    # the download took about half a day.
print("Fetching MAG240M dataset and storing it in ", MAG240_DATASET_PATH)
if not os.path.exists(MAG240_DATASET_PATH):
os.makedirs(MAG240_DATASET_PATH)
dataset = MAG240MDataset(root=MAG240_DATASET_PATH)
return dataset
dataset = fetch_dataset()
author_writes_paper = dataset.edge_index("author", "paper")
author_affiliated_with_institution = dataset.edge_index("author", "institution")
paper_cites_paper = dataset.edge_index("paper", "paper")
paper_feat = dataset.all_paper_feat
paper_label = dataset.all_paper_label
paper_year = dataset.all_paper_year
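Before defining the Avro schemas, it can help to confirm the array layouts: the edge index arrays are expected to have shape (2, num_edges), which is why they are transposed before being written below, while the node arrays have one row (or entry) per paper. A minimal sanity check, assuming the standard OGB layout:
# Optional sanity check of array shapes (exact sizes depend on the OGB release)
print(author_writes_paper.shape)  # expected (2, num_author_writes_paper_edges)
print(paper_cites_paper.shape)  # expected (2, num_citation_edges)
print(paper_feat.shape)  # expected (num_papers, dataset.num_paper_features)
print(paper_label.shape, paper_year.shape)  # expected (num_papers,), (num_papers,)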
# Define all the Avro schemas
fields_node_paper = [
{"name": "paper", "type": "int"},
]
for featIdx in range(dataset.num_paper_features):
fields_node_paper.append({"name": f"feat_{featIdx}", "type": "float"})
schema_node_paper = {"type": "record", "name": "Paper", "fields": fields_node_paper}
schema_node_paper_label = {
"type": "record",
"name": "PaperLabel",
"fields": [{"name": "paper", "type": "int"}, {"name": "label", "type": "float"}],
}
schema_node_paper_year = {
"type": "record",
"name": "PaperYear",
"fields": [{"name": "paper", "type": "int"}, {"name": "year", "type": "int"}],
}
schema_edge_author_writes_paper = {
"type": "record",
"name": "AuthorWritesPaper",
"fields": [{"name": "author", "type": "int"}, {"name": "paper", "type": "int"}],
}
schema_edge_author_afil_with_institution = {
"type": "record",
"name": "AuthorAfiliatedWithInstitution",
"fields": [
{"name": "author", "type": "int"},
{"name": "institution", "type": "int"},
],
}
schema_edge_paper_cites_paper = {
"type": "record",
"name": "PaperCitesPaper",
"fields": [{"name": "src", "type": "int"}, {"name": "dst", "type": "int"}],
}
# We write records to Avro format, flushing the buffer to a new GCS file whenever it exceeds ~150MB
class BufferedGCSAvroWriter:
def __init__(
self,
schema: dict,
gcs_bucket_path: str,
max_buffer_bytes=1.5e8, # 150MB
):
self.schema = schema
self.gcs_bucket_path = gcs_bucket_path
self.max_buffer_bytes = max_buffer_bytes
self.buffer = BytesIO()
self.file_index = 0
def flush(self):
# Reset buffer position to the beginning
self.buffer.seek(0)
# Initialize GCS client and upload the file
storage_client = storage.Client()
split_path = self.gcs_bucket_path.split("/")
bucket_name = split_path[2]
blob_name_prefix = "/".join(split_path[3:])
destination_blob_name = f"{blob_name_prefix}_{self.file_index}.avro"
print(
f"Avro file will be uploaded to gs://{bucket_name}/{destination_blob_name}"
)
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(destination_blob_name)
blob.upload_from_file(self.buffer, content_type="application/octet-stream")
print(f"Avro file uploaded to gs://{bucket_name}/{destination_blob_name}")
self.buffer = BytesIO()
self.file_index += 1
def write(self, record: dict):
# Write Avro data to the buffer using fastavro
writer(self.buffer, self.schema, [record])
if self.buffer.tell() > self.max_buffer_bytes:
self.flush()
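A minimal usage sketch for the writer, assuming BASE_GCS_BUCKET_PATH points at a bucket you can write to (the record and destination below are illustrative only):
# example_writer = BufferedGCSAvroWriter(
#     schema=schema_node_paper_year,
#     gcs_bucket_path=f"{BASE_GCS_BUCKET_PATH}/example_paper_year/0",
# )
# example_writer.write({"paper": 0, "year": 2019})
# example_writer.flush()  # upload whatever remains in the buffer as a final file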
def __write_table_to_gcs_proc(
schema: dict, edge_table: np.ndarray, gcs_bucket_path: str, proc_num: int
):
print(f"Begin writing edge table to GCS with proc_num: {proc_num}")
gcs_path_for_writer = f"{gcs_bucket_path}/{proc_num}"
avro_writer = BufferedGCSAvroWriter(
schema=schema, gcs_bucket_path=gcs_path_for_writer
)
field_names = [field_info["name"] for field_info in schema["fields"]]
print(f"Will write the following fields: {field_names}")
for edge in edge_table.T:
obj = {field_names[i]: edge[i] for i in range(len(field_names))}
avro_writer.write(obj)
print(f"Proc {proc_num} finished writing edge table to GCS.")
avro_writer.flush()
def write_edge_table_to_gcs(schema: dict, edge_table: np.ndarray, gcs_bucket_path: str):
gcs_utils = GcsUtils()
print(f"Clearing GCS path {gcs_bucket_path}")
gcs_utils.delete_files_in_bucket_dir(gcs_path=GcsUri(gcs_bucket_path))
    num_procs = 10
    edge_table_chunks = np.array_split(edge_table, num_procs, axis=1)
with multiprocessing.Pool(processes=num_procs) as pool:
pool.starmap(
__write_table_to_gcs_proc,
[
(schema, edge_table_chunk, gcs_bucket_path, i)
for i, edge_table_chunk in enumerate(edge_table_chunks)
],
)
def __write_node_table_to_gcs_proc(
schema: dict,
node_table: np.ndarray,
gcs_bucket_path: str,
proc_num: int,
enumerate_starting_from: Optional[int] = None,
):
print(f"Begin writing node table to GCS with proc_num: {proc_num}")
gcs_path_for_writer = f"{gcs_bucket_path}/{proc_num}"
avro_writer = BufferedGCSAvroWriter(
schema=schema, gcs_bucket_path=gcs_path_for_writer
)
field_names = [field_info["name"] for field_info in schema["fields"]]
print(f"Will write the following fields: {field_names}")
    # If the node table is 1-dimensional, reshape it to a single column so rows can be enumerated consistently
    if len(node_table.shape) == 1:
        node_table = node_table.reshape(-1, 1)
curr_count = enumerate_starting_from
for node in node_table:
obj: dict
if curr_count is not None:
obj = {field_names[0]: curr_count}
for i in range(1, len(field_names)):
obj[field_names[i]] = node[i - 1]
curr_count += 1
else:
obj = {field_names[i]: node[i] for i in range(len(field_names))}
avro_writer.write(obj)
print(f"Proc {proc_num} finished writing node table to GCS.")
avro_writer.flush()
def write_node_table_to_gcs(schema: dict, node_table: np.ndarray, gcs_bucket_path: str):
gcs_utils = GcsUtils()
print(f"Clearing GCS path {gcs_bucket_path}")
gcs_utils.delete_files_in_bucket_dir(gcs_path=GcsUri(gcs_bucket_path))
    num_procs = 10
    node_table_chunks = np.array_split(node_table, num_procs, axis=0)
chunk_sizes = [len(chunk) for chunk in node_table_chunks]
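    # Node tables carry no explicit ID column, so each chunk's starting global row index
    # (the sum of the preceding chunk sizes) is passed as the enumeration offset below.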
with multiprocessing.Pool(processes=num_procs) as pool:
pool.starmap(
__write_node_table_to_gcs_proc,
[
(schema, node_table_chunk, gcs_bucket_path, i, sum(chunk_sizes[:i]))
for i, node_table_chunk in enumerate(node_table_chunks)
],
)
class AvroToBqWriter:
@staticmethod
def write_to_bq(gcs_uri: str, table_id: str):
bq_utils = BqUtils()
bq_utils.create_or_empty_bq_table(bq_path=table_id)
result = bq_utils.load_file_to_bq(
source_path=gcs_uri,
bq_path=table_id,
source_format=bigquery.SourceFormat.AVRO,
write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
)
print(f"Finished writing to BQ table {table_id}; with result: {result}")
num_rows = bq_utils.count_number_of_rows_in_bq_table(table_id)
print(f"Loaded {num_rows} rows to {table_id}")
edge_assets_to_export = [
("author_writes_paper", author_writes_paper, schema_edge_author_writes_paper),
(
"author_affiliated_with_institution",
author_affiliated_with_institution,
schema_edge_author_afil_with_institution,
),
("paper_cites_paper", paper_cites_paper, schema_edge_paper_cites_paper),
]
for asset_name, edge_table, schema in edge_assets_to_export:
asset_gcs_path = get_gcs_path_for_asset(asset_name)
asset_bq_path = get_bq_path_for_asset(asset_name)
print(f"Beggining exporting assets for {asset_name}")
write_edge_table_to_gcs(
schema=schema,
edge_table=edge_table,
gcs_bucket_path=asset_gcs_path,
)
AvroToBqWriter.write_to_bq(
gcs_uri=f"{asset_gcs_path}/*.avro",
table_id=f"{asset_bq_path}",
)
nodes_tables_to_export = [
("paper", paper_feat, schema_node_paper),
("paper_label", paper_label, schema_node_paper_label),
("paper_year", paper_year, schema_node_paper_year),
]
for asset_name, node_table, schema in nodes_tables_to_export:
    asset_gcs_path = get_gcs_path_for_asset(asset_name)
    asset_bq_path = get_bq_path_for_asset(asset_name)
    print(f"Beginning export of assets for {asset_name}")
write_node_table_to_gcs(
schema=schema,
node_table=node_table,
gcs_bucket_path=asset_gcs_path,
)
AvroToBqWriter.write_to_bq(
gcs_uri=f"{asset_gcs_path}/*.avro",
table_id=f"{asset_bq_path}",
)