(Optional) Fetch MAG240M Data into Your Own Project

The latest version of this notebook can be found on GitHub.

Install Requirements

!pip install ogb

import multiprocessing
import os
from io import BytesIO
from typing import Optional

import numpy as np
from common import MAG240_DATASET_PATH
from fastavro import writer
from google.cloud import bigquery, storage
from ogb.lsc import MAG240MDataset

from gigl.common import GcsUri
from gigl.common.utils.gcs import GcsUtils
from gigl.src.common.utils.bq import BqUtils

# WARN: Replace BOTH placeholders below with your own locations before running.
# GCS prefix for the exported Avro files, of the form "gs://<bucket>/<prefix>"
# (the uploader takes the bucket name from the third "/"-separated component).
BASE_GCS_BUCKET_PATH = "CHANGE THIS WITH YOUR GCS PATH"
# BigQuery table prefix; each asset is loaded into "<BASE_BQ_PATH>_<asset_name>".
BASE_BQ_PATH = "CHANGE THIS WITH YOUR BQ PATH"


def get_gcs_path_for_asset(asset_name):
    """Return the GCS prefix for ``asset_name``: a sub-path of ``BASE_GCS_BUCKET_PATH``."""
    return "{}/{}".format(BASE_GCS_BUCKET_PATH, asset_name)


def get_bq_path_for_asset(asset_name):
    """Return the BigQuery table path for ``asset_name``: ``BASE_BQ_PATH`` + ``_`` + name."""
    return "{}_{}".format(BASE_BQ_PATH, asset_name)
def fetch_dataset() -> MAG240MDataset:
    """Load the OGB-LSC MAG240M dataset, downloading it if it is not present.

    The dataset lives under ``MAG240_DATASET_PATH``. The directory is created
    if it does not exist yet.

    Returns:
        The loaded ``MAG240MDataset``.
    """
    # We pull the dataset from ogb (`pip install -U ogb`): https://ogb.stanford.edu/docs/lsc/
    # Example of how this dataset can be used:
    # https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/gnn.py
    # WARNING: If the dataset is not already available locally, this can take
    # hours or even days to run. The upstream servers are slow: even on cloud
    # instances with 100GB/s+ networking, the download took half a day.
    print("Fetching MAG240M dataset and storing it in ", MAG240_DATASET_PATH)
    # exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(MAG240_DATASET_PATH, exist_ok=True)
    return MAG240MDataset(root=MAG240_DATASET_PATH)


dataset = fetch_dataset()

# Edge index arrays, one per relation type. Downstream,
# write_edge_table_to_gcs splits these along axis=1 and maps row i of each
# column to schema field i, i.e. they are consumed as (2, num_edges) arrays.
(
    author_writes_paper,
    author_affiliated_with_institution,
    paper_cites_paper,
) = [
    dataset.edge_index(src_type, dst_type)
    for src_type, dst_type in (
        ("author", "paper"),
        ("author", "institution"),
        ("paper", "paper"),
    )
]

# Per-paper node attributes (one row/entry per paper): features, label, year.
paper_feat = dataset.all_paper_feat
paper_label = dataset.all_paper_label
paper_year = dataset.all_paper_year
# Avro schemas for every exported table.

# Paper nodes: the paper id followed by one float column per feature
# dimension, named feat_0 ... feat_{num_paper_features - 1}.
schema_node_paper = dict(
    type="record",
    name="Paper",
    fields=[dict(name="paper", type="int")]
    + [
        dict(name=f"feat_{feature_index}", type="float")
        for feature_index in range(dataset.num_paper_features)
    ],
)
fields_node_paper = schema_node_paper["fields"]

# Per-paper label (float) keyed by paper id.
schema_node_paper_label = dict(
    type="record",
    name="PaperLabel",
    fields=[dict(name="paper", type="int"), dict(name="label", type="float")],
)

# Per-paper publication year (int) keyed by paper id.
schema_node_paper_year = dict(
    type="record",
    name="PaperYear",
    fields=[dict(name="paper", type="int"), dict(name="year", type="int")],
)

# Edge tables: one record per edge holding its two endpoint ids.
schema_edge_author_writes_paper = dict(
    type="record",
    name="AuthorWritesPaper",
    fields=[dict(name="author", type="int"), dict(name="paper", type="int")],
)

schema_edge_author_afil_with_institution = dict(
    type="record",
    name="AuthorAfiliatedWithInstitution",
    fields=[
        dict(name="author", type="int"),
        dict(name="institution", type="int"),
    ],
)

schema_edge_paper_cites_paper = dict(
    type="record",
    name="PaperCitesPaper",
    fields=[dict(name="src", type="int"), dict(name="dst", type="int")],
)
# Write records in Avro format, uploading the in-memory buffer to a new,
# sequentially numbered GCS file whenever it grows past max_buffer_bytes (~150MB by default).
class BufferedGCSAvroWriter:
    """Stream records to GCS as a sequence of numbered Avro container files.

    Records are serialized into an in-memory Avro container. Whenever the
    buffer grows past ``max_buffer_bytes``, it is uploaded to
    ``gs://<bucket>/<prefix>_<file_index>.avro`` and a fresh container is
    started. Call ``flush()`` after the last ``write()`` to upload whatever is
    still buffered.

    Args:
        schema: Avro record schema dict passed to ``fastavro.writer``.
        gcs_bucket_path: Destination prefix of the form ``gs://<bucket>/<prefix>``.
            The bucket is taken from the third ``/``-separated component.
        max_buffer_bytes: Buffered size (bytes) above which the buffer is uploaded.
        records_per_block: Number of records serialized per ``fastavro.writer``
            call, i.e. per Avro data block. Batching amortizes fastavro's
            per-call overhead (schema parsing and, when appending to a
            non-empty buffer, re-reading the container header) instead of
            paying it for every record. The buffer may exceed
            ``max_buffer_bytes`` by at most one block before it is uploaded.
    """

    def __init__(
        self,
        schema: dict,
        gcs_bucket_path: str,
        max_buffer_bytes=1.5e8,  # 150MB
        records_per_block: int = 1_000,
    ):
        self.schema = schema
        self.gcs_bucket_path = gcs_bucket_path
        self.max_buffer_bytes = max_buffer_bytes
        # Clamp to >= 1 so a non-positive value cannot stall serialization.
        self.records_per_block = max(1, int(records_per_block))
        self.buffer = BytesIO()
        self.file_index = 0
        # Records accepted by write() but not yet serialized into self.buffer.
        self._pending_records: list = []
        # GCS client, created lazily on the first upload and then reused.
        self._storage_client = None

    def _buffered_bytes(self) -> int:
        """Return the number of bytes currently held in ``self.buffer``."""
        # seek() returns the new absolute position, i.e. the buffer size.
        return self.buffer.seek(0, os.SEEK_END)

    def _serialize_pending_records(self):
        """Append all pending records to ``self.buffer`` as a single Avro block."""
        if not self._pending_records:
            return
        # Ensure fastavro writes after any existing data.
        self.buffer.seek(0, os.SEEK_END)
        # On an empty buffer fastavro writes the container header first. On a
        # non-empty (seekable, readable) buffer it appends a data block to the
        # existing container, which the per-record original relied on as well.
        writer(self.buffer, self.schema, self._pending_records)
        self._pending_records = []

    def flush(self):
        """Upload everything written so far as the next numbered ``.avro`` file.

        Does nothing when no records were written since the last upload. This
        avoids creating a zero-byte object, which is not a valid Avro container.
        """
        self._serialize_pending_records()
        if self._buffered_bytes() == 0:
            return
        # Upload from the beginning of the buffer.
        self.buffer.seek(0)
        if self._storage_client is None:
            self._storage_client = storage.Client()

        split_path = self.gcs_bucket_path.split("/")
        bucket_name = split_path[2]
        blob_name_prefix = "/".join(split_path[3:])
        destination_blob_name = f"{blob_name_prefix}_{self.file_index}.avro"
        print(
            f"Avro file will be uploaded to gs://{bucket_name}/{destination_blob_name}"
        )

        bucket = self._storage_client.bucket(bucket_name)
        blob = bucket.blob(destination_blob_name)
        blob.upload_from_file(self.buffer, content_type="application/octet-stream")
        print(f"Avro file uploaded to gs://{bucket_name}/{destination_blob_name}")
        self.buffer = BytesIO()
        self.file_index += 1

    def write(self, record: dict):
        """Queue ``record`` for writing.

        Once ``records_per_block`` records are queued they are serialized as one
        block. The buffer is uploaded when it exceeds ``max_buffer_bytes``.
        """
        self._pending_records.append(record)
        if len(self._pending_records) < self.records_per_block:
            return
        self._serialize_pending_records()
        if self._buffered_bytes() > self.max_buffer_bytes:
            self.flush()


def __write_table_to_gcs_proc(
    schema: dict, edge_table: np.ndarray, gcs_bucket_path: str, proc_num: int
):
    """Worker: write one chunk of an edge table to GCS as Avro records.

    Each column of ``edge_table`` becomes one record. Its i-th entry is stored
    under the schema's i-th field name. Output is written through a
    ``BufferedGCSAvroWriter`` rooted at ``{gcs_bucket_path}/{proc_num}``.

    Args:
        schema: Avro record schema whose fields line up with the table's rows.
        edge_table: Edge chunk with one edge per column (the caller splits along axis=1).
        gcs_bucket_path: GCS prefix shared by all workers.
        proc_num: Worker index; keeps this worker's file names distinct.
    """
    print(f"Begin writing edge table to GCS with proc_num: {proc_num}")
    avro_writer = BufferedGCSAvroWriter(
        schema=schema, gcs_bucket_path=f"{gcs_bucket_path}/{proc_num}"
    )
    column_names = [spec["name"] for spec in schema["fields"]]
    print(f"Will write the following fields: {column_names}")

    for edge_column in edge_table.T:
        record = {}
        for position, column_name in enumerate(column_names):
            record[column_name] = edge_column[position]
        avro_writer.write(record)
    print(f"Proc {proc_num} finished writing edge table to GCS.")
    avro_writer.flush()


def write_edge_table_to_gcs(
    schema: dict,
    edge_table: np.ndarray,
    gcs_bucket_path: str,
    *,
    num_procs: int = 10,
):
    """Clear ``gcs_bucket_path`` and write ``edge_table`` there as Avro files in parallel.

    The table is split column-wise (one edge per column) into ``num_procs``
    contiguous chunks. Each chunk is written by its own worker process under
    ``{gcs_bucket_path}/{proc_num}``.

    Args:
        schema: Avro record schema; field i receives row i of each edge column.
        edge_table: Edge index array, one edge per column.
        gcs_bucket_path: Destination prefix ``gs://<bucket>/<prefix>``. Existing
            files under it are deleted first.
        num_procs: Number of chunks / worker processes.
    """
    gcs_utils = GcsUtils()
    print(f"Clearing GCS path {gcs_bucket_path}")
    gcs_utils.delete_files_in_bucket_dir(gcs_path=GcsUri(gcs_bucket_path))
    # Split into exactly num_procs chunks so the chunk count matches the pool size.
    edge_table_chunks = np.array_split(edge_table, num_procs, axis=1)
    # Skip empty chunks (fewer edges than workers) so no worker uploads an empty
    # file. The original proc_num is kept so file prefixes stay stable.
    work_items = [
        (schema, edge_table_chunk, gcs_bucket_path, proc_num)
        for proc_num, edge_table_chunk in enumerate(edge_table_chunks)
        if edge_table_chunk.shape[1] > 0
    ]
    if not work_items:
        print(f"Edge table is empty; nothing to write to {gcs_bucket_path}")
        return
    with multiprocessing.Pool(processes=len(work_items)) as pool:
        pool.starmap(__write_table_to_gcs_proc, work_items)


def __write_node_table_to_gcs_proc(
    schema: dict,
    node_table: np.ndarray,
    gcs_bucket_path: str,
    proc_num: int,
    enumerate_starting_from: Optional[int] = None,
):
    """Worker: write the rows of one node-table chunk to GCS as Avro records.

    Args:
        schema: Avro record schema.
        node_table: Node chunk with one node per row. A 1-D array is treated as
            a single column.
        gcs_bucket_path: GCS prefix shared by all workers. This worker writes
            under ``{gcs_bucket_path}/{proc_num}``.
        proc_num: Worker index; keeps this worker's file names distinct.
        enumerate_starting_from: If not None, the first schema field receives a
            running counter starting at this value (e.g. the chunk's global row
            offset), and the row's values fill the remaining fields in order.
            If None, the row's values fill all schema fields in order.

    Raises:
        ValueError: If the table is not 1-D/2-D or its row width does not match
            the number of schema fields to fill. Without this check, extra
            columns would be silently dropped.
    """
    print(f"Begin writing node table to GCS with proc_num: {proc_num}")
    field_names = [field_info["name"] for field_info in schema["fields"]]
    # A 1-D table (one value per node) is handled as a single-column table.
    if node_table.ndim == 1:
        node_table = node_table.reshape(-1, 1)
    # Fields populated from the row itself (all but the id field when enumerating).
    value_field_names = (
        field_names if enumerate_starting_from is None else field_names[1:]
    )
    if node_table.ndim != 2 or node_table.shape[1] != len(value_field_names):
        raise ValueError(
            f"Node table chunk has shape {node_table.shape}, but schema "
            f"{schema.get('name')!r} expects {len(value_field_names)} value column(s)."
        )

    avro_writer = BufferedGCSAvroWriter(
        schema=schema, gcs_bucket_path=f"{gcs_bucket_path}/{proc_num}"
    )
    print(f"Will write the following fields: {field_names}")

    next_id = enumerate_starting_from
    for node in node_table:
        if next_id is None:
            obj = dict(zip(field_names, node))
        else:
            # The id field comes first, followed by the row's values in order.
            obj = {field_names[0]: next_id}
            obj.update(zip(value_field_names, node))
            next_id += 1
        avro_writer.write(obj)
    print(f"Proc {proc_num} finished writing node table to GCS.")
    avro_writer.flush()


def write_node_table_to_gcs(
    schema: dict,
    node_table: np.ndarray,
    gcs_bucket_path: str,
    *,
    num_procs: int = 10,
):
    """Clear ``gcs_bucket_path`` and write ``node_table`` there as Avro files in parallel.

    The table is split row-wise into ``num_procs`` contiguous chunks. Each chunk
    is written by its own worker process under ``{gcs_bucket_path}/{proc_num}``.
    Each worker numbers its rows starting at the chunk's global row offset, so
    the first schema field receives the row's index in ``node_table``.

    Args:
        schema: Avro record schema; its first field receives the row index.
        node_table: Node attributes, one row per node. A 1-D array is treated
            as a single column.
        gcs_bucket_path: Destination prefix ``gs://<bucket>/<prefix>``. Existing
            files under it are deleted first.
        num_procs: Number of chunks / worker processes.
    """
    gcs_utils = GcsUtils()
    print(f"Clearing GCS path {gcs_bucket_path}")
    gcs_utils.delete_files_in_bucket_dir(gcs_path=GcsUri(gcs_bucket_path))
    # Split into exactly num_procs chunks so the chunk count matches the pool size.
    node_table_chunks = np.array_split(node_table, num_procs, axis=0)

    # Pair each non-empty chunk with its global row offset (running sum of the
    # preceding chunk lengths). Empty chunks (fewer rows than workers) are
    # skipped but still counted, so proc numbers and offsets stay consistent.
    work_items = []
    row_offset = 0
    for proc_num, node_table_chunk in enumerate(node_table_chunks):
        if len(node_table_chunk) > 0:
            work_items.append(
                (schema, node_table_chunk, gcs_bucket_path, proc_num, row_offset)
            )
        row_offset += len(node_table_chunk)
    if not work_items:
        print(f"Node table is empty; nothing to write to {gcs_bucket_path}")
        return

    # NOTE(review): each chunk is pickled to its worker. For very large (e.g.
    # memory-mapped) tables this presumably copies the chunk's data; consider
    # passing row ranges and letting workers slice a shared table instead.
    with multiprocessing.Pool(processes=len(work_items)) as pool:
        pool.starmap(__write_node_table_to_gcs_proc, work_items)


class AvroToBqWriter:
    """Load Avro files from GCS into a BigQuery table."""

    @staticmethod
    def write_to_bq(gcs_uri: str, table_id: str):
        """Replace the contents of ``table_id`` with the Avro data at ``gcs_uri``.

        The table is first created or emptied. The load then uses
        ``WRITE_TRUNCATE``, and the resulting row count is printed.

        Args:
            gcs_uri: GCS source URI for the Avro file(s); callers in this file
                pass a ``.../*.avro`` wildcard.
            table_id: Destination BigQuery table path.
        """
        bq = BqUtils()
        bq.create_or_empty_bq_table(bq_path=table_id)
        load_options = {
            "source_format": bigquery.SourceFormat.AVRO,
            "write_disposition": bigquery.WriteDisposition.WRITE_TRUNCATE,
        }
        load_outcome = bq.load_file_to_bq(
            source_path=gcs_uri, bq_path=table_id, **load_options
        )
        print(f"Finished writing to BQ table {table_id}; with result: {load_outcome}")

        row_total = bq.count_number_of_rows_in_bq_table(table_id)
        print(f"Loaded {row_total} rows to {table_id}")
# Export each edge table: write it to GCS as Avro, then load it into BigQuery.
edge_assets_to_export = [
    ("author_writes_paper", author_writes_paper, schema_edge_author_writes_paper),
    (
        "author_affiliated_with_institution",
        author_affiliated_with_institution,
        schema_edge_author_afil_with_institution,
    ),
    ("paper_cites_paper", paper_cites_paper, schema_edge_paper_cites_paper),
]

for asset_name, edge_table, schema in edge_assets_to_export:
    asset_gcs_path = get_gcs_path_for_asset(asset_name)
    asset_bq_path = get_bq_path_for_asset(asset_name)
    print(f"Beginning export of assets for {asset_name}")
    write_edge_table_to_gcs(
        schema=schema,
        edge_table=edge_table,
        gcs_bucket_path=asset_gcs_path,
    )
    # Workers write files named "<proc>_<index>.avro" directly under asset_gcs_path.
    AvroToBqWriter.write_to_bq(
        gcs_uri=f"{asset_gcs_path}/*.avro",
        table_id=asset_bq_path,
    )
# Export each node table: write it to GCS as Avro, then load it into BigQuery.
nodes_tables_to_export = [
    ("paper", paper_feat, schema_node_paper),
    ("paper_label", paper_label, schema_node_paper_label),
    ("paper_year", paper_year, schema_node_paper_year),
]

for asset_name, node_table, schema in nodes_tables_to_export:
    asset_gcs_path = get_gcs_path_for_asset(asset_name)
    asset_bq_path = get_bq_path_for_asset(asset_name)
    print(f"Beginning export of assets for {asset_name} (GCS path: {asset_gcs_path})")
    write_node_table_to_gcs(
        schema=schema,
        node_table=node_table,
        gcs_bucket_path=asset_gcs_path,
    )
    # Workers write files named "<proc>_<index>.avro" directly under asset_gcs_path.
    AvroToBqWriter.write_to_bq(
        gcs_uri=f"{asset_gcs_path}/*.avro",
        table_id=asset_bq_path,
    )