(Optional) Fetch MAG240M Data into your own project

Install requirements

!pip install ogb

from typing import Optional
import os 
import multiprocessing
import numpy as np
from io import BytesIO
from fastavro import writer
from google.cloud import storage
from google.cloud import bigquery

from ogb.lsc import MAG240MDataset
from common import MAG240_DATASET_PATH

from gigl.common.utils.gcs import GcsUtils
from gigl.src.common.utils.bq import BqUtils
from gigl.common import GcsUri


BASE_GCS_BUCKET_PATH = "CHANGE THIS TO YOUR GCS PATH"  # WARN: change this to your own GCS path
BASE_BQ_PATH = "CHANGE THIS TO YOUR BQ PATH"  # WARN: change this to your own BQ path

def get_gcs_path_for_asset(asset_name):
    return f"{BASE_GCS_BUCKET_PATH}/{asset_name}"

def get_bq_path_for_asset(asset_name):
    return f"{BASE_BQ_PATH}_{asset_name}"
def fetch_dataset() -> MAG240MDataset:
    # We pull the dataset from OGB (installed above via `pip install ogb`): https://ogb.stanford.edu/docs/lsc/
    # Example of how this dataset can be used: https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/gnn.py
    # WARNING: If the dataset is not already available locally, this code block can take hours or even days to run.
    # The bottleneck is the upstream download servers: even on cloud instances with 100 GB/s+ networking,
    # the download took about half a day.
    print("Fetching MAG240M dataset and storing it in ", MAG240_DATASET_PATH)
    if not os.path.exists(MAG240_DATASET_PATH):
        os.makedirs(MAG240_DATASET_PATH)
    dataset = MAG240MDataset(root = MAG240_DATASET_PATH)
    return dataset

dataset = fetch_dataset()
author_writes_paper = dataset.edge_index('author', 'paper')
author_affiliated_with_institution = dataset.edge_index('author', 'institution')
paper_cites_paper = dataset.edge_index('paper', 'paper')

paper_feat = dataset.all_paper_feat
paper_label = dataset.all_paper_label
paper_year = dataset.all_paper_year
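
As a quick sanity check before exporting, it is worth confirming the shapes of what was just loaded; per the OGB MAG240M API, edge indexes come back as (2, num_edges) arrays and paper features as (num_papers, num_paper_features). A minimal sketch:

print("num_papers:", dataset.num_papers)
print("num_authors:", dataset.num_authors)
print("num_institutions:", dataset.num_institutions)
print("paper_feat shape:", paper_feat.shape)    # (num_papers, num_paper_features)
print("paper_label shape:", paper_label.shape)  # (num_papers,); NaN for unlabeled papers
print("author_writes_paper shape:", author_writes_paper.shape)  # (2, num_edges)
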
# Define all the Avro schemas
fields_node_paper = [
    {"name": "paper", "type": "int"},
]
for feat_idx in range(dataset.num_paper_features):
    fields_node_paper.append({"name": f"feat_{feat_idx}", "type": "float"})

schema_node_paper = {
    "type": "record",
    "name": "Paper",
    "fields": fields_node_paper
}

schema_node_paper_label = {
    "type": "record",
    "name": "PaperLabel",
    "fields": [
        {"name": "paper", "type": "int"},
        {"name": "label", "type": "float"}
    ]
}

schema_node_paper_year = {
    "type": "record",
    "name": "PaperYear",
    "fields": [
        {"name": "paper", "type": "int"},
        {"name": "year", "type": "int"}
    ]
}

schema_edge_author_writes_paper = {
    "type": "record",
    "name": "AuthorWritesPaper",
    "fields": [
        {"name": "author", "type": "int"},
        {"name": "paper", "type": "int"}
    ]
}

schema_edge_author_afil_with_institution = {
    "type": "record",
    "name": "AuthorAfiliatedWithInstitution",
    "fields": [
        {"name": "author", "type": "int"},
        {"name": "institution", "type": "int"}
    ]
}

schema_edge_paper_cites_paper = {
    "type": "record",
    "name": "PaperCitesPaper",
    "fields": [
        {"name": "src", "type": "int"},
        {"name": "dst", "type": "int"}
    ]
}
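
Before streaming millions of rows, it can help to validate one hand-built record against a schema with fastavro; a minimal sketch (the sample values below are placeholders):

from fastavro import parse_schema
from fastavro.validation import validate

sample_edge = {"src": 0, "dst": 1}
# validate() raises if the record does not match the schema (or returns False with raise_errors=False).
assert validate(sample_edge, parse_schema(schema_edge_paper_cites_paper))
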
# We write records in Avro format and flush the buffer to a new GCS file (with an incremented suffix) once it exceeds ~150MB
class BufferedGCSAvroWriter:
    def __init__(
            self, 
            schema: dict, 
            gcs_bucket_path: str, 
            max_buffer_bytes=1.5e+8, # 150MB
        ):
        self.schema = schema
        self.gcs_bucket_path = gcs_bucket_path
        self.max_buffer_bytes = max_buffer_bytes
        self.buffer = BytesIO()
        self.file_index = 0

        
    def flush(self):
        # Skip the upload if nothing has been written since the last flush;
        # an empty blob would not be a valid Avro file.
        if self.buffer.tell() == 0:
            return
        # Reset buffer position to the beginning before uploading
        self.buffer.seek(0)
        # Initialize GCS client and upload the file
        storage_client = storage.Client()
        
        split_path = self.gcs_bucket_path.split("/")
        bucket_name = split_path[2]
        blob_name_prefix = "/".join(split_path[3:])
        destination_blob_name = f"{blob_name_prefix}_{self.file_index}.avro"
        print(f"Avro file will be uploaded to gs://{bucket_name}/{destination_blob_name}")

        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(destination_blob_name)
        blob.upload_from_file(self.buffer, content_type="application/octet-stream")
        print(f"Avro file uploaded to gs://{bucket_name}/{destination_blob_name}")
        self.buffer = BytesIO()
        self.file_index += 1
    
    def write(self, record: dict):
        # Append the record to the in-memory Avro container. fastavro detects that the
        # buffer is seekable and non-empty, so subsequent calls append new blocks after
        # the existing header instead of rewriting it.
        writer(self.buffer, self.schema, [record])
        if self.buffer.tell() > self.max_buffer_bytes:
            self.flush()
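
A note on the write path above: fastavro treats a seekable, non-empty output as appendable, so repeated writer() calls on the same BytesIO append new blocks after the existing header rather than producing nested files (this assumes a reasonably recent fastavro). A quick in-memory round trip to convince yourself:

from fastavro import reader

_buf = BytesIO()
writer(_buf, schema_edge_paper_cites_paper, [{"src": 0, "dst": 1}])
writer(_buf, schema_edge_paper_cites_paper, [{"src": 1, "dst": 2}])  # appended as a new block
_buf.seek(0)
print(list(reader(_buf)))  # expect both records back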

def __write_table_to_gcs_proc(schema: dict, edge_table: np.ndarray, gcs_bucket_path: str, proc_num: int):
    print(f"Begin writing edge table to GCS with proc_num: {proc_num}")
    gcs_path_for_writer = f"{gcs_bucket_path}/{proc_num}"
    avro_writer = BufferedGCSAvroWriter(schema=schema, gcs_bucket_path=gcs_path_for_writer)
    field_names = [field_info["name"] for field_info in schema["fields"]]
    print(f"Will write the following fields: {field_names}")
    
    for edge in edge_table.T:
        obj = {field_names[i]: edge[i] for i in range(len(field_names))}
        avro_writer.write(obj)
    print(f"Proc {proc_num} finished writing edge table to GCS.")
    avro_writer.flush()

def write_edge_table_to_gcs(schema: dict, edge_table: np.ndarray, gcs_bucket_path: str):
    gcs_utils = GcsUtils()
    print(f"Clearing GCS path {gcs_bucket_path}")
    gcs_utils.delete_files_in_bucket_dir(gcs_path = GcsUri(gcs_bucket_path))
    num_procs = 10
    edge_table_chunks = np.array_split(edge_table, num_procs, axis=1)
    with multiprocessing.Pool(processes=num_procs) as pool:
        pool.starmap(__write_table_to_gcs_proc, [
            (schema, edge_table_chunk, gcs_bucket_path, i)
            for i, edge_table_chunk in enumerate(edge_table_chunks)
        ])
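
The edge indexes are (2, num_edges) arrays, so the split above is along axis=1 and each worker iterates edge_table.T to get one (src, dst) row per edge. A toy illustration of that chunking:

toy_edges = np.array([[0, 1, 2, 3, 4],     # source ids
                      [5, 6, 7, 8, 9]])    # destination ids
toy_chunks = np.array_split(toy_edges, 2, axis=1)
print([c.shape for c in toy_chunks])  # [(2, 3), (2, 2)]: edges are divided, both rows kept
print(toy_chunks[0].T)                # rows of (src, dst) pairs, as each worker iterates them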


def __write_node_table_to_gcs_proc(schema: dict, node_table: np.ndarray, gcs_bucket_path: str, proc_num: int, enumerate_starting_from: Optional[int] = None):
    print(f"Begin writing node table to GCS with proc_num: {proc_num}")
    gcs_path_for_writer = f"{gcs_bucket_path}/{proc_num}"
    avro_writer = BufferedGCSAvroWriter(schema=schema, gcs_bucket_path=gcs_path_for_writer)
    field_names = [field_info["name"] for field_info in schema["fields"]]
    print(f"Will write the following fields: {field_names}")
    # If the node table is 1-D (e.g. labels or years), reshape it to 2-D so the
    # per-field indexing below works uniformly.
    if len(node_table.shape) == 1:
        node_table = node_table.reshape(-1, 1)
    curr_count = enumerate_starting_from
    for node in node_table:
        obj: dict
        if curr_count is not None:
            obj = {field_names[0]: curr_count}
            for i in range(1, len(field_names)):
                obj[field_names[i]] = node[i-1]
            curr_count += 1
        else:
            obj = {field_names[i]: node[i] for i in range(len(field_names))}

        avro_writer.write(obj)
    print(f"Proc {proc_num} finished writing node table to GCS.")
    avro_writer.flush()

def write_node_table_to_gcs(schema: dict, node_table: np.ndarray, gcs_bucket_path: str):
    gcs_utils = GcsUtils()
    print(f"Clearing GCS path {gcs_bucket_path}")
    gcs_utils.delete_files_in_bucket_dir(gcs_path = GcsUri(gcs_bucket_path))
    num_procs = 10
    node_table_chunks = np.array_split(node_table, num_procs, axis=0)
    chunk_sizes = [len(chunk) for chunk in node_table_chunks]
    with multiprocessing.Pool(processes=num_procs) as pool:
        pool.starmap(__write_node_table_to_gcs_proc, [
            (schema, node_table_chunk, gcs_bucket_path, i, sum(chunk_sizes[:i]))
            for i, node_table_chunk in enumerate(node_table_chunks)
        ])
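
Node tables carry no explicit id column, so the node id is recovered from the row position: each worker gets the prefix sum of the preceding chunk sizes as enumerate_starting_from, which keeps ids contiguous and consistent across chunks. A small sketch of the offset bookkeeping:

toy_nodes = np.arange(25).reshape(25, 1)                 # 25 rows, 1 feature each
toy_chunks = np.array_split(toy_nodes, 10, axis=0)
toy_sizes = [len(chunk) for chunk in toy_chunks]
toy_offsets = [sum(toy_sizes[:i]) for i in range(len(toy_sizes))]
print(toy_sizes)    # [3, 3, 3, 3, 3, 2, 2, 2, 2, 2]
print(toy_offsets)  # [0, 3, 6, 9, 12, 15, 17, 19, 21, 23] -- first global id handled by each worker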

    


class AvroToBqWriter:
    @staticmethod
    def write_to_bq(gcs_uri: str, table_id: str):
        bq_utils = BqUtils()
        bq_utils.create_or_empty_bq_table(bq_path=table_id)
        result = bq_utils.load_file_to_bq(
            source_path=gcs_uri,
            bq_path=table_id,
            source_format=bigquery.SourceFormat.AVRO,
            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        )
        print(f"Finished writing to BQ table {table_id}; with result: {result}")

        num_rows = bq_utils.count_number_of_rows_in_bq_table(table_id)
        print(f"Loaded {num_rows} rows to {table_id}")
edge_assets_to_export = [
    ("author_writes_paper", author_writes_paper, schema_edge_author_writes_paper),
    ("author_affiliated_with_institution", author_affiliated_with_institution, schema_edge_author_afil_with_institution),
    ("paper_cites_paper", paper_cites_paper, schema_edge_paper_cites_paper),
    
]

for asset_name, edge_table, schema in edge_assets_to_export:
    asset_gcs_path = get_gcs_path_for_asset(asset_name)
    asset_bq_path = get_bq_path_for_asset(asset_name)
    print(f"Beggining exporting assets for {asset_name}")
    write_edge_table_to_gcs(
        schema=schema,
        edge_table=edge_table,
        gcs_bucket_path=asset_gcs_path,
    )
    AvroToBqWriter.write_to_bq(
        gcs_uri=f"{asset_gcs_path}/*.avro",
        table_id=f"{asset_bq_path}",
    )
nodes_tables_to_export = [
    ("paper", paper_feat, schema_node_paper),
    ("paper_label", paper_label, schema_node_paper_label),
    ("paper_year", paper_year, schema_node_paper_year),
]

for asset_name, node_table, schema in nodes_tables_to_export:
    asset_gcs_path = get_gcs_path_for_asset(asset_name)
    asset_bq_path = get_bq_path_for_asset(asset_name)
    print(f"Beginning export of assets for {asset_name}")
    write_node_table_to_gcs(
        schema=schema,
        node_table=node_table,
        gcs_bucket_path=asset_gcs_path,
    )
    AvroToBqWriter.write_to_bq(
        gcs_uri=f"{asset_gcs_path}/*.avro",
        table_id=f"{asset_bq_path}",
    )