Source code for gigl.src.mocking.lib.pyg_datasets_forks
"""
Our mocking logic uses public datasets such as Cora and DBLP from PyG. PyG downloads these
datasets from public sources, which may be unavailable or may rate-limit us. We therefore
override the dataset classes to download the datasets from GCS buckets and avoid these issues.
"""
from torch_geometric.data import extract_zip
from torch_geometric.datasets import DBLP, Planetoid
import gigl.env.dep_constants as dep_constants
from gigl.common import GcsUri, LocalUri
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.utils.file_loader import FileLoader
[docs]
# Base GCS location, inside the bucket named by dep_constants.GIGL_PUBLIC_BUCKET_NAME,
# that the dataset classes in this module join their dataset-specific paths onto.
unprocessed_datasets_gcs_uri = GcsUri(
f"gs://{dep_constants.GIGL_PUBLIC_BUCKET_NAME}/unprocessed_datasets/"
)
[docs]
class DBLPFromGCS(DBLP):
    """PyG ``DBLP`` dataset whose archive is loaded from a GCS path.

    Only ``url`` and ``download`` are overridden here; all other behavior is
    inherited from ``DBLP``.
    """

    # The file from https://www.dropbox.com/s/yh4grpeks87ugr2/DBLP_processed.zip?dl=1 was copied to the below GCS path.
    url = GcsUri.join(unprocessed_datasets_gcs_uri, "pyg/dblp/DBLP_processed.zip")

    def download(self) -> None:
        """Load the DBLP zip from ``self.url`` into ``self.raw_dir`` and extract it there."""
        project = get_resource_config().project
        loader = FileLoader(project=project)
        # The archive is written next to the files it will be extracted into.
        archive_uri = LocalUri.join(self.raw_dir, "DBLP.zip")
        loader.load_file(file_uri_src=self.url, file_uri_dst=archive_uri)
        extract_zip(archive_uri.uri, self.raw_dir)
[docs]
class CoraFromGCS(Planetoid):
    """PyG ``Planetoid`` dataset, restricted to Cora, with raw files loaded from GCS.

    Only ``url`` and ``download`` are overridden here; all other behavior is
    inherited from ``Planetoid``.
    """

    # The files from https://github.com/kimiyoung/planetoid/tree/master/data were copied to the below GCS path.
    url = GcsUri.join(unprocessed_datasets_gcs_uri, "pyg/planetoid/")

    def download(self) -> None:
        """Load every file in ``self.raw_file_names`` from ``self.url`` into ``self.raw_dir``.

        Raises:
            AssertionError: If ``self.name`` is not ``"Cora"``.
        """
        # Raised explicitly instead of using an ``assert`` statement so the guard
        # still runs under ``python -O``. The exception type and message are kept
        # unchanged for backward compatibility.
        if self.name != "Cora":
            raise AssertionError("Only Cora dataset is supported")

        file_loader = FileLoader(project=get_resource_config().project)
        # Each raw file keeps its name: <url>/<name> -> <raw_dir>/<name>.
        source_to_dest_file_uri_map = {
            GcsUri.join(self.url, name): LocalUri.join(self.raw_dir, name)
            for name in self.raw_file_names
        }
        file_loader.load_files(source_to_dest_file_uri_map=source_to_dest_file_uri_map)