# Reload imported modules automatically before executing each cell, so edits to
# GiGL / example source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

from gigl.common.utils.jupyter_magics import change_working_dir_to_gigl_root
# Move the kernel's working directory to the GiGL repo root so the relative
# paths used in later cells (e.g. "examples/link_prediction/configs/...") resolve.
change_working_dir_to_gigl_root()

Cora Distributed Training Example

This notebook will walk you through how to use GiGL to train a model on the CORA dataset in a distributed fashion. At the end of this notebook you will have:

  1. Preprocessed the CORA dataset and saved it as TFRecord files to GCS

  2. Trained a model based on the Graph, across multiple machines using Torch Distributed constructs (DDP)

  3. Performed inference on the trained model, saving the resulting embeddings to BigQuery

If you are more interested in the fine details of individual components, or the GBML protos, please see toy_example_walkthrough.ipynb which provides in-depth explanations of what each component is doing.

The latest version of this notebook can be found on github

NOTE: This notebook and dblp.ipynb are very similar, and differ in the following ways:

  • The TEMPLATE_TASK_CONFIG_URIs are using different task specs

  • The "Examining the trained model" cells use homogeneous/heterogeneous models and PyG constructs, as appropriate for the dataset.

  • The "Looking at inference results" cells are for homogeneous/heterogeneous inference results (e.g. if there are multiple BQ tables for the different node types).

Setting up GCP Project and configs

Assuming you have a GCP project setup:

  1. Open up examples/link_prediction/configs/example_resource_config.yaml and fill in all relevant fields under common_compute_config:

  • project

  • region

  • temp_assets_bucket

  • temp_regional_assets_bucket

  • perm_assets_bucket

  • temp_assets_bq_dataset_name

  • embedding_bq_dataset_name

  • gcp_service_account_email

  2. Ensure your service account has the relevant permissions. See our cloud setup guide.

import datetime
import getpass
import os

from gigl.common import LocalUri
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import GiglResourceConfigWrapper

# Timestamp shared by the job name and the docker image tags built later,
# so every run of this notebook gets unique names.
curr_datetime = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Job name: <current user>_gigl_cora_<timestamp>.
JOB_NAME = f"{getpass.getuser()}_gigl_cora_{curr_datetime}"

# Template task config for the homogeneous CORA example.
TEMPLATE_TASK_CONFIG_URI = LocalUri(
    "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml"
)

# Use the resource config named by GIGL_TEST_DEFAULT_RESOURCE_CONFIG when that
# environment variable is set; otherwise fall back to the example resource config.
RESOURCE_CONFIG_URI = LocalUri(
    os.environ.get(
        "GIGL_TEST_DEFAULT_RESOURCE_CONFIG",
        "examples/link_prediction/configs/example_resource_config.yaml",
    )
)
print(f"Using resource config URI: {RESOURCE_CONFIG_URI}")

# Load both configs now to make sure they exist and can be parsed before any
# pipeline work starts.
TEMPLATE_TASK_CONFIG: GbmlConfigPbWrapper = (
    GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
        gbml_config_uri=TEMPLATE_TASK_CONFIG_URI
    )
)
RESOURCE_CONFIG: GiglResourceConfigWrapper = get_resource_config(
    resource_config_uri=RESOURCE_CONFIG_URI
)
PROJECT = RESOURCE_CONFIG.project

print(f"Succesfully found task config and resource config. Script will help execute job: {JOB_NAME} on project: {PROJECT}")
# Sanity-check the task and resource configs before launching any pipeline step.
from gigl.src.validation_check.config_validator import kfp_validation_checks

# start_at="config_populator": the config populator is the first pipeline step; it
# populates the template task config above and produces a frozen config.
kfp_validation_checks(
    job_name=JOB_NAME,
    start_at="config_populator",
    task_config_uri=TEMPLATE_TASK_CONFIG_URI,
    resource_config_uri=RESOURCE_CONFIG_URI,
)

Compiling Src Docker Images

You will need to build and push docker images with your custom code so that individual GiGL components can leverage your code. For this experiment we will consider the CORA specs and code to be “custom code”, and we will guide you how to build a docker image with the code.

We will make use of scripts/build_and_push_docker_image.py for this.

Note that this builds containers/Dockerfile.src and containers/Dockerfile.dataflow.src, which contain instructions to COPY the examples folder (which contains all the source code for CORA); the resulting images also contain all of the GiGL src code.

from concurrent.futures import ThreadPoolExecutor
from scripts.build_and_push_docker_image import build_and_push_cpu_image, build_and_push_cuda_image, build_and_push_dataflow_image

DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG = f"us-central1-docker.pkg.dev/{PROJECT}/gigl-base-images/gigl_dataflow_runtime:{curr_datetime}"
DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG = f"us-central1-docker.pkg.dev/{PROJECT}/gigl-base-images/gigl_cuda:{curr_datetime}"
DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG = f"us-central1-docker.pkg.dev/{PROJECT}/gigl-base-images/gigl_cpu:{curr_datetime}"

# Build and push the three images concurrently.
# executor.submit() returns a Future; an exception raised inside a build is stored
# on that Future and is silently lost unless .result() is called. Waiting on every
# Future here makes a failed build fail this cell instead of falling through to the
# "We built and pushed..." message below.
with ThreadPoolExecutor(max_workers=3) as executor:
    build_futures = [
        executor.submit(
            build_and_push_dataflow_image,
            image_name=DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG,
        ),
        executor.submit(
            build_and_push_cuda_image,
            image_name=DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG,
        ),
        executor.submit(
            build_and_push_cpu_image,
            image_name=DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG,
        ),
    ]
    for build_future in build_futures:
        # Blocks until this build finishes; re-raises its exception if it failed.
        build_future.result()

print(f"""We built and pushed the following docker images:
- {DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG}
- {DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG}
- {DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG}
""")

We will instantiate a local runner to help orchestrate the test pipeline

from gigl.orchestration.local.runner import PipelineConfig, Runner

# Local orchestrator used by the following cells to run pipeline components.
runner = Runner()

# Inputs shared by every component the runner launches: the job identifier,
# the task/resource configs, and the docker images built above.
pipeline_config = PipelineConfig(
    applied_task_identifier=JOB_NAME,
    task_config_uri=TEMPLATE_TASK_CONFIG_URI,
    resource_config_uri=RESOURCE_CONFIG_URI,
    custom_cuda_docker_uri=DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG,
    custom_cpu_docker_uri=DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG,
    dataflow_docker_uri=DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG,
)

First, we will run the config populator

The config populator takes in a template GbmlConfig and outputs a frozen GbmlConfig by populating all job-related metadata paths in sharedConfig. These are mostly GCS paths which the following components read from and write to, and use as an intermediary data communication medium. For example, the field sharedConfig.trainedModelMetadata is populated with a GCS URI, which tells the Trainer to write the trained model to this path, and the Inferencer to read the model from this path.

from gigl.src.common.utils.file_loader import FileLoader

# Turn the template task config into a frozen config, then load it back for inspection.
frozen_config_uri = runner.run_config_populator(pipeline_config=pipeline_config)
frozen_config = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
    gbml_config_uri=frozen_config_uri
)
file_loader = FileLoader()

print(f"Config Populator has successfully generated the following frozen config from the template ({TEMPLATE_TASK_CONFIG_URI}) :")
print(frozen_config.gbml_config_pb)

# From here on, components must be driven by the frozen config rather than the template.
pipeline_config.task_config_uri = frozen_config_uri

Visualizing the Diff Between Template and Frozen Config

We now have a frozen task config, whose location is stored in frozen_config_uri. Below, we visualize the diff between the frozen task config generated by the config_populator and the original template task config. The code below exists only for this visualization and is not part of the GiGL pipeline itself.

Specifically, note that:

  1. The component added sharedConfig to the YAML, which contains all the intermediary and final output paths for each component.

  2. It also added a condensedEdgeTypeMap and a condensedNodeTypeMap, which map all provided edge types and node types to int to save storage space:

    • EdgeType: Tuple[srcNodeType: str, relation: str, dstNodeType: str] -> int, and

    • NodeType: str -> int

    • Note: You may also provide your own condensedMaps; they will be generated for you if not provided.

from gigl.common.utils.jupyter_magics import show_task_config_colored_unified_diff

# Colored unified diff of the frozen config (f1) against the template it was built from (f2).
show_task_config_colored_unified_diff(
    f1_uri=frozen_config_uri,
    f1_name="frozen_task_config.yaml",
    f2_uri=TEMPLATE_TASK_CONFIG_URI,
    f2_name="template_task_config.yaml",
)

Next, we run the preprocessor

The Data Preprocessor reads node, edge and respective feature data from a data source, and produces preprocessed / transformed versions of all this data, for subsequent components to use. It uses Tensorflow Transform to achieve data transformation in a distributed fashion, and allows for transformations like categorical encoding, scaling, normalization, casting and more.

In this case we are using preprocessing spec defined in python/gigl/src/mocking/mocking_assets/passthrough_preprocessor_config_for_mocked_assets.py - take a look for more details.

You will notice that the preprocessor first creates a few BQ jobs to prepare the node and edge tables, and subsequently kicks off TFT (Dataflow) jobs to do the actual preprocessing. The preprocessor will: (1) create a preprocessing spec and dump it to the path specified in the frozen config's sharedConfig.preprocessedMetadataUri; (2) run the respective Dataflow jobs, which dump the preprocessed assets as .tfrecord files to the paths specified inside the preprocessing spec at preprocessedMetadataUri.

The preprocessor will also enumerate all node ids, remapping the node ids as integers. See the preprocessor docs for more information.

# WARN: There is an issue when trying to run dataflow jobs from inside a jupyter kernel; thus we cannot use the
# commented-out line below to run the preprocessor as you would normally in a python script.
# runner.run_data_preprocessor(pipeline_config=pipeline_config)

# Instead, we will run the preprocessor from the command line. IPython substitutes the
# $-prefixed names below with the values of the notebook's Python variables.
# Note: You can actually do this with every component; we just make use of the runner to make it easier to run the components.
!python -m gigl.src.data_preprocessor.data_preprocessor \
--job_name=$JOB_NAME \
--task_config_uri=$frozen_config_uri \
--resource_config_uri=$RESOURCE_CONFIG_URI \
--custom_worker_image_uri=$DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG
## Preprocessor outputs.
from snapchat.research.gbml.preprocessed_metadata_pb2 import PreprocessedMetadata

# Reload the frozen config now that the preprocessor has run, so that
# `preprocessed_metadata_pb_wrapper` reflects the preprocessor's outputs.
# NOTE(review): presumably the wrapper loaded before the preprocessor ran does not
# see the newly written preprocessed metadata — confirm against GbmlConfigPbWrapper.
frozen_config = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(gbml_config_uri=frozen_config_uri)

# Let's see what the preprocessor has outputted.
print(frozen_config.preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb)

# The feature keys can make the message hard to read, so print a copy with them cleared.
# CopyFrom keeps the wrapper's own message untouched.
filtered_preprocessed_metadata = PreprocessedMetadata()
filtered_preprocessed_metadata.CopyFrom(frozen_config.preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb)
for node_type in filtered_preprocessed_metadata.condensed_node_type_to_preprocessed_metadata:
    filtered_preprocessed_metadata.condensed_node_type_to_preprocessed_metadata[node_type].ClearField("feature_keys")
for edge_type in filtered_preprocessed_metadata.condensed_edge_type_to_preprocessed_metadata:
    filtered_preprocessed_metadata.condensed_edge_type_to_preprocessed_metadata[edge_type].main_edge_info.ClearField("feature_keys")
print("More readable preprocessed metadata:")
print(filtered_preprocessed_metadata)

Training the model

The Trainer component reads the pre-processed graphs stored as TFRecords on GCS (whose paths are specified in the frozen config), and trains a GNN model on the training set, early stops on the performance of the validation set, and finally evaluates on the test set. The training logic is implemented with PyTorch Distributed Data Parallel (DDP) Training, which enables distributed training on multiple GPU cards across multiple worker nodes.

The trainer reads the graph data which are stored as TFRecords on GCS, whose locations are stored at GbmlConfig.SharedConfig.preprocessed_metadata_uri.

Once the model is trained, the model weights will be saved to the URI located at GbmlConfig.SharedConfig.TrainedModelMetadata.trained_model_uri

# Train the model; the trainer saves weights to the frozen config's trained_model_uri.
runner.run_trainer(
    pipeline_config=pipeline_config,
)
### Examining the trained model
# The trained model will be saved at: `GbmlConfig.SharedConfig.TrainedModelMetadata.trained_model_uri`
trained_model_uri = frozen_config.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri
print(trained_model_uri)

# You can load the model locally and play around with it:
import torch
from torch_geometric.data import Data

from examples.link_prediction.models import init_example_gigl_homogeneous_model
from gigl.common import UriFactory
from gigl.src.common.utils.model import load_state_dict_from_uri

cpu_device = torch.device("cpu")

# Feature dimensions of the (single) node and edge types of the homogeneous graph.
node_type = frozen_config.graph_metadata_pb_wrapper.homogeneous_condensed_node_type
edge_type = frozen_config.graph_metadata_pb_wrapper.homogeneous_condensed_edge_type
node_feature_dim = frozen_config.preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map[node_type]
edge_feature_dim = frozen_config.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map[edge_type]

# Rebuild the model architecture and load the trained weights into it.
model = init_example_gigl_homogeneous_model(
    node_feature_dim=node_feature_dim,
    edge_feature_dim=edge_feature_dim,
    device=cpu_device,
    state_dict=load_state_dict_from_uri(UriFactory.create_uri(trained_model_uri)),
)
# Switch to inference mode so layers such as dropout/batch-norm (if the model has
# any) behave deterministically when we run it below.
model.eval()
print(model)

# Create some random data to test the model: 10 nodes and 20 random edges.
example_data = Data(x=torch.rand((10, node_feature_dim)), edge_index=torch.randint(0, 10, (2, 20)))
# No gradients are needed just to look at the embeddings.
with torch.no_grad():
    embeddings = model(example_data, device=cpu_device)
print(embeddings)

Inference

The Inferencer component is responsible for running inference of a trained model on samples generated on the fly during live subgraph sampling. At a high level, it works by applying a trained model in an embarrassingly parallel and distributed fashion across these samples, and persisting the output embeddings and/or predictions.

# Run inference with the trained model; embeddings are produced for enumerated node ids.
runner.run_inferencer(pipeline_config=pipeline_config)

Post Processor

The inferencer outputs embeddings for enumerated node ids, i.e. the node ids that the preprocessor outputs. The preprocessor stores a mapping between unenumerated node ids and enumerated node ids in PreprocessedMetadata.condensed_node_type_to_preprocessed_metadata[node_type].enumerated_node_ids_bq_table.

The postprocessor unenumerates the embeddings using that mapping and stores the results in the BigQuery table specified in the frozen config at sharedConfig.inferenceMetadata.nodeTypeToInferencerOutputInfoMap[node_type].embeddingsPath.

# WARN: There is an issue when trying to run dataflow jobs from inside a jupyter kernel; thus we cannot
# launch the post processor through the runner as you would normally in a python script.

# Instead, we will run the post processor from the command line. IPython substitutes the
# $-prefixed names below with the values of the notebook's Python variables.
# Note: You can actually do this with every component; we just make use of the runner to make it easier to run the components.
!python -m gigl.src.post_process.post_processor \
--job_name=$JOB_NAME \
--task_config_uri=$frozen_config_uri \
--resource_config_uri=$RESOURCE_CONFIG_URI 
# Looking at inference results
# Note we need to do this *after* we run the post-processor.
# As in the live-subgraph sampling world, the post processor unenumerates the node ids per the mapping in
# `PreprocessedMetadata.condensed_node_type_to_preprocessed_metadata[node_type].enumerated_node_ids_bq_table`
# and stores the resulting embeddings in the BQ table specified in the frozen config.

from gigl.src.common.utils.bq import BqUtils

# Look up the embeddings table for the graph's single node type. Access shared_config
# through gbml_config_pb, as the other cells in this notebook do.
homogeneous_node_type = frozen_config.graph_metadata_pb_wrapper.homogeneous_node_type
bq_emb_out_table = frozen_config.gbml_config_pb.shared_config.inference_metadata.node_type_to_inferencer_output_info_map[homogeneous_node_type].embeddings_path
print(f"Embeddings should be successfully stored in the following location: {bq_emb_out_table}")

bq_utils = BqUtils(project=PROJECT)
# Backtick-quote the table path: quoted identifiers are always valid in BigQuery
# standard SQL, and they avoid parse errors when the path contains characters such
# as ':' that unquoted identifiers don't allow.
query = f"SELECT * FROM `{bq_emb_out_table}` LIMIT 5"
result = list(bq_utils.run_query(query=query, labels={}))

print(f"Query result: {result}")