%load_ext autoreload
%autoreload 2
from gigl.common.utils.jupyter_magics import change_working_dir_to_gigl_root
change_working_dir_to_gigl_root()
Changed working directory to: /home/svij/GiGL
Toy Example#
This notebook is a walkthrough of training and inference on a small toy graph using GiGL
Overview Of Components#
This notebook shows the process of a simple, human-digestible graph being passed through all the pipeline components in GiGL in preparation for training, to help understand how each of the components works.
The pipeline consists of the following components (a minimal sketch of chaining them together follows this list):
Config Populator: Takes a template config and creates a frozen workflow config that dictates all inputs/outputs and business parameters, which is read and used by each subsequent component.
input: template_config.yaml
output: frozen_gbml_config.yaml
Data Preprocessor: Transforms node and edge feature assets as needed, as a precursor step to most ML tasks, according to a user-provided data preprocessor config class.
input: frozen_gbml_config.yaml, which includes the user-defined preprocessor class for custom logic; custom arguments can be passed under dataPreprocessorArgs
output: PreprocessedMetadata proto, which includes inferred GraphMetadata and preprocessed graph data TFRecords produced by applying the user-defined preprocessing function
Subgraph Sampler: Samples k-hop subgraphs for each node according to user-provided arguments.
input: frozen_gbml_config.yaml, resource_config.yaml
output: Subgraph samples (TFRecord format, following the predefined schemas in protos) stored at the URI defined in the flattenedGraphMetadata field
Split Generator: Splits Subgraph Sampler outputs into train/test/val sets according to a user-provided split strategy class.
input: frozen_gbml_config.yaml, which includes an instance of SplitStrategy and an instance of Assigner
output: TFRecord samples
Trainer: Reads the output of Split Generator, trains a model on the training set, early-stops based on the validation set, and evaluates on the test set.
input: frozen_gbml_config.yaml
output: state_dict stored at trainedModelUri
Inferencer: Runs inference of a trained model on samples generated by Subgraph Sampler.
input: frozen_gbml_config.yaml
output: Embeddings and/or prediction assets
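For orientation, below is a minimal sketch of how these components chain together when driven from Python. It simply shells out to the same module entry points and flags used later in this notebook (the Data Preprocessor additionally accepts a custom Dataflow worker image, shown later); the paths here are placeholders.
import subprocess

JOB_NAME = "gigl_test_job"
TEMPLATE_CONFIG = "template_task_config.yaml"           # placeholder
RESOURCE_CONFIG = "resource_config.yaml"                # placeholder
FROZEN_CONFIG_POINTER = "/tmp/frozen_task_config.yaml"  # placeholder

# 1. Config Populator: template config -> frozen config.
subprocess.run([
    "python", "-m", "gigl.src.config_populator.config_populator",
    f"--job_name={JOB_NAME}",
    f"--template_uri={TEMPLATE_CONFIG}",
    f"--resource_config_uri={RESOURCE_CONFIG}",
    f"--output_file_path_frozen_gbml_config_uri={FROZEN_CONFIG_POINTER}",
], check=True)

# The pointer file contains the path of the frozen config that all later components read.
with open(FROZEN_CONFIG_POINTER) as f:
    frozen_config_uri = f.read().strip()

# 2-4. Data Preprocessor, Subgraph Sampler, and Split Generator share the same CLI shape.
for component in (
    "gigl.src.data_preprocessor.data_preprocessor",
    "gigl.src.subgraph_sampler.subgraph_sampler",
    "gigl.src.split_generator.split_generator",
):
    subprocess.run([
        "python", "-m", component,
        f"--job_name={JOB_NAME}",
        f"--task_config_uri={frozen_config_uri}",
        f"--resource_config_uri={RESOURCE_CONFIG}",
    ], check=True)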
Input Graph#
We use the input graph defined in examples/toy_visual_example/graph_config.yaml. You are welcome to change this file to a custom graph of your own choosing.
Visualizing the input graph#
from torch_geometric.data import HeteroData
from examples.toy_visual_example.visualize import GraphVisualizer
from gigl.src.mocking.toy_asset_mocker import load_toy_graph
original_graph_heterodata: HeteroData = load_toy_graph(graph_config="examples/toy_visual_example/graph_config.yaml")
# Visualize the graph
GraphVisualizer.visualize_graph(original_graph_heterodata)

Setting up Configs#
The first thing we will need to do is create the resource and task configs.
Task Config: Specifies task-related configurations, guiding the behavior of components according to the needs of your machine learning task. See Task Config Guide. For this task, we have already provided a task config: task_config.yaml
Resource Config: Details the resource allocation and environment settings across all GiGL components. This encompasses shared resources for all components, as well as component-specific settings. See Resource Config Guide. For this task we provide resource_config.yaml. Note, however, that the default values provided under shared_resource_config.common_compute_config are for resources you will not have access to unless you are a core contributor. Instructions to make the resource config work: if you have not already, please follow the Quick Start Guide to set up your cloud environment and a default test resource config. You can then copy the relevant shared_resource_config.common_compute_config into resource_config.yaml
import os
import pathlib
import textwrap
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from gigl.common import Uri, UriFactory
notebook_dir = pathlib.Path("./examples/toy_visual_example").as_posix() # We should be in root dir because of cell # 1
# Add the root gigl dir to the Python path so the `examples` folder can be imported as a module.
# You are welcome to customize these to point to your own configuration files.
JOB_NAME = "gigl_test_job"
TEMPLATE_TASK_CONFIG_PATH: Uri = UriFactory.create_uri(f"{notebook_dir}/template_task_config.yaml")
FROZEN_TASK_CONFIG_POINTER_FILE_PATH: Uri = UriFactory.create_uri(f"/tmp/GiGL/{JOB_NAME}/frozen_task_config.yaml")
pathlib.Path(FROZEN_TASK_CONFIG_POINTER_FILE_PATH.uri).parent.mkdir(parents=True, exist_ok=True)
# Ensure you change the resource config path to point to your own resource configuration
# i.e. what was exported to $GIGL_TEST_DEFAULT_RESOURCE_CONFIG as part of the quick start guide.
RESOURCE_CONFIG_PATH: Uri = UriFactory.create_uri(f"{notebook_dir}/resource_config.yaml")
# Export string format of the uris so we can reference them in cells that execute bash commands below.
os.environ["TEMPLATE_TASK_CONFIG_PATH"] = TEMPLATE_TASK_CONFIG_PATH.uri
os.environ["FROZEN_TASK_CONFIG_POINTER_FILE_PATH"] = FROZEN_TASK_CONFIG_POINTER_FILE_PATH.uri
os.environ["RESOURCE_CONFIG_PATH"] = RESOURCE_CONFIG_PATH.uri
Note on use of mocked assets#
This step is already done for you. We provide instructions below for posterity, in case the mocked data input “graph_config.yaml” is updated.
Note: unless you are a core contributor, you will not have access to write to the public BQ/GCS resources. In that case, you can choose to update MOCK_DATA_GCS_BUCKET and MOCK_DATA_BQ_DATASET_NAME in python/gigl/src/mocking/lib/constants.py to upload to resources you own.
We run the following command to upload the relevant mocks to GCS and BQ:
python -m gigl.src.mocking.dataset_asset_mocking_suite \
--select mock_toy_graph_homogeneous_node_anchor_based_link_prediction_dataset \
--resource_config_uri=examples/toy_visual_example/resource_config.yaml
Subsequently, we can update the paths in task_config.yaml
Validating the configs#
We provide the ability to validate your resource and task configs. Although the validation is not exhaustive, it helps assert that the most common issues are not present before expensive compute is scheduled.
from gigl.src.validation_check.config_validator import kfp_validation_checks
validator = kfp_validation_checks(
    job_name=JOB_NAME,
    task_config_uri=TEMPLATE_TASK_CONFIG_PATH,
    resource_config_uri=RESOURCE_CONFIG_PATH,
    start_at="config_populator",
)
2025-06-04 00:19:25.518371: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-04 00:19:25.518424: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-04 00:19:25.519867: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-04 00:19 [INFO] : Using Any for unsupported type: typing.Sequence[~T] (native_type_compatibility.py:convert_to_beam_type:340)
2025-06-04 00:19 [INFO] : Config validation check: if job_name: gigl_test_job is valid. (name_checks.py:check_if_kfp_pipeline_job_name_valid:15)
2025-06-04 00:19 [INFO] : Creating symlink /tmp/tmpy67fu068 -> /home/svij/GiGL/examples/toy_visual_example/template_task_config.yaml (local_fs.py:create_file_symlinks:231)
2025-06-04 00:19 [INFO] : Deleted local file at /tmp/tmpy67fu068 (local_fs.py:remove_file_if_exist:64)
2025-06-04 00:19 [WARNING] : preprocessedMetadataUri is not set in the GbmlConfig. Please use ConfigPopulator to populate the preprocessedMetadata or designate it yourself. (gbml_config.py:__load_preprocessed_metadata_pb_wrapper:124)
2025-06-04 00:19 [INFO] : Skipping populating dataset_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:75)
2025-06-04 00:19 [INFO] : Skipping populating flattened_graph_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:90)
2025-06-04 00:19 [INFO] : Skipping populating subgraph_sampling_strategy_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:111)
2025-06-04 00:19 [INFO] : Creating symlink /tmp/tmp0_3cy7wd -> /home/svij/GiGL/examples/toy_visual_example/template_task_config.yaml (local_fs.py:create_file_symlinks:231)
2025-06-04 00:19 [INFO] : Deleted local file at /tmp/tmp0_3cy7wd (local_fs.py:remove_file_if_exist:64)
2025-06-04 00:19 [INFO] : Creating symlink /tmp/tmpfy_y06al -> /home/svij/GiGL/examples/toy_visual_example/resource_config.yaml (local_fs.py:create_file_symlinks:231)
2025-06-04 00:19 [INFO] : Deleted local file at /tmp/tmpfy_y06al (local_fs.py:remove_file_if_exist:64)
2025-06-04 00:19 [INFO] : Config validation check: if graphMetadata is valid. (template_config_checks.py:check_if_graph_metadata_valid:218)
2025-06-04 00:19 [INFO] : Config validation check: if taskMetadata is valid. (template_config_checks.py:check_if_task_metadata_valid:81)
2025-06-04 00:19 [INFO] : Config validation check: if dataPreprocessorConfigClsPath and its args are valid. (template_config_checks.py:check_if_data_preprocessor_config_cls_valid:238)
2025-06-04 00:19 [INFO] : Will try importing module: examples.toy_visual_example.toy_data_preprocessor_config.ToyDataPreprocessorConfig (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "examples.toy_visual_example.toy_data_preprocessor_config.ToyDataPreprocessorConfig". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'examples.toy_visual_example.toy_data_preprocessor_config.ToyDataPreprocessorConfig'; 'examples.toy_visual_example.toy_data_preprocessor_config' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: examples.toy_visual_example.toy_data_preprocessor_config (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'examples.toy_visual_example.toy_data_preprocessor_config' from '/home/svij/GiGL/examples/toy_visual_example/toy_data_preprocessor_config.py'>, which potentially has object = ToyDataPreprocessorConfig (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access ToyDataPreprocessorConfig in <module 'examples.toy_visual_example.toy_data_preprocessor_config' from '/home/svij/GiGL/examples/toy_visual_example/toy_data_preprocessor_config.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Found BQ table reference.bq_nodes_table_name: external-snap-ci-github-gigl.public_gigl.toy_graph_homogeneous_node_anchor_lp_user_nodes_2025-06-02--04-42-21-UTC, bq_edges_table_name: external-snap-ci-github-gigl.public_gigl.toy_graph_homogeneous_node_anchor_lp_user-friend-user_edges_main_2025-06-02--04-42-21-UTC (toy_data_preprocessor_config.py:__init__:38)
2025-06-04 00:19 [INFO] : Config validation check: if subgraphSamplerConfig is valid. (template_config_checks.py:check_if_subgraph_sampler_config_valid:368)
2025-06-04 00:19 [WARNING] : preprocessedMetadataUri is not set in the GbmlConfig. Please use ConfigPopulator to populate the preprocessedMetadata or designate it yourself. (gbml_config.py:__load_preprocessed_metadata_pb_wrapper:124)
2025-06-04 00:19 [INFO] : Skipping populating dataset_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:75)
2025-06-04 00:19 [INFO] : Skipping populating flattened_graph_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:90)
2025-06-04 00:19 [INFO] : Skipping populating subgraph_sampling_strategy_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:111)
2025-06-04 00:19 [INFO] : Config validation check: if splitGeneratorConfig is valid. (template_config_checks.py:check_if_split_generator_config_valid:336)
2025-06-04 00:19 [WARNING] : preprocessedMetadataUri is not set in the GbmlConfig. Please use ConfigPopulator to populate the preprocessedMetadata or designate it yourself. (gbml_config.py:__load_preprocessed_metadata_pb_wrapper:124)
2025-06-04 00:19 [INFO] : Skipping populating dataset_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:75)
2025-06-04 00:19 [INFO] : Skipping populating flattened_graph_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:90)
2025-06-04 00:19 [INFO] : Skipping populating subgraph_sampling_strategy_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:111)
2025-06-04 00:19 [INFO] : Config validation check: if trainerClsPath and its args are valid. (template_config_checks.py:check_if_trainer_cls_valid:271)
2025-06-04 00:19 [WARNING] : preprocessedMetadataUri is not set in the GbmlConfig. Please use ConfigPopulator to populate the preprocessedMetadata or designate it yourself. (gbml_config.py:__load_preprocessed_metadata_pb_wrapper:124)
2025-06-04 00:19 [INFO] : Skipping populating dataset_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:75)
2025-06-04 00:19 [INFO] : Skipping populating flattened_graph_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:90)
2025-06-04 00:19 [INFO] : Skipping populating subgraph_sampling_strategy_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:111)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec.NodeAnchorBasedLinkPredictionModelingTaskSpec (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec.NodeAnchorBasedLinkPredictionModelingTaskSpec". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec.NodeAnchorBasedLinkPredictionModelingTaskSpec'; 'gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec' from '/home/svij/GiGL/python/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py'>, which potentially has object = NodeAnchorBasedLinkPredictionModelingTaskSpec (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access NodeAnchorBasedLinkPredictionModelingTaskSpec in <module 'gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec' from '/home/svij/GiGL/python/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.models.pyg.homogeneous.GraphSAGE (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "gigl.src.common.models.pyg.homogeneous.GraphSAGE". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'gigl.src.common.models.pyg.homogeneous.GraphSAGE'; 'gigl.src.common.models.pyg.homogeneous' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.models.pyg.homogeneous (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'gigl.src.common.models.pyg.homogeneous' from '/home/svij/GiGL/python/gigl/src/common/models/pyg/homogeneous.py'>, which potentially has object = GraphSAGE (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access GraphSAGE in <module 'gigl.src.common.models.pyg.homogeneous' from '/home/svij/GiGL/python/gigl/src/common/models/pyg/homogeneous.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Will try importing module: torch.optim.Adam (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "torch.optim.Adam". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'torch.optim.Adam' (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: torch.optim (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'torch.optim' from '/opt/conda/envs/gnn/lib/python3.9/site-packages/torch/optim/__init__.py'>, which potentially has object = Adam (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access Adam in <module 'torch.optim' from '/opt/conda/envs/gnn/lib/python3.9/site-packages/torch/optim/__init__.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Will try importing module: torch.optim.lr_scheduler.ConstantLR (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "torch.optim.lr_scheduler.ConstantLR". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'torch.optim.lr_scheduler.ConstantLR'; 'torch.optim.lr_scheduler' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: torch.optim.lr_scheduler (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'torch.optim.lr_scheduler' from '/opt/conda/envs/gnn/lib/python3.9/site-packages/torch/optim/lr_scheduler.py'>, which potentially has object = ConstantLR (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access ConstantLR in <module 'torch.optim.lr_scheduler' from '/opt/conda/envs/gnn/lib/python3.9/site-packages/torch/optim/lr_scheduler.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.models.layers.task.Retrieval (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "gigl.src.common.models.layers.task.Retrieval". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'gigl.src.common.models.layers.task.Retrieval'; 'gigl.src.common.models.layers.task' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.models.layers.task (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'gigl.src.common.models.layers.task' from '/home/svij/GiGL/python/gigl/src/common/models/layers/task.py'>, which potentially has object = Retrieval (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access Retrieval in <module 'gigl.src.common.models.layers.task' from '/home/svij/GiGL/python/gigl/src/common/models/layers/task.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Identified task <class 'gigl.src.common.models.layers.task.Retrieval'> (node_anchor_based_link_prediction_modeling_task_spec.py:__init__:198)
2025-06-04 00:19 [INFO] : Config validation check: if inferencerClsPath and its args are valid. (template_config_checks.py:check_if_inferencer_cls_valid:302)
2025-06-04 00:19 [WARNING] : preprocessedMetadataUri is not set in the GbmlConfig. Please use ConfigPopulator to populate the preprocessedMetadata or designate it yourself. (gbml_config.py:__load_preprocessed_metadata_pb_wrapper:124)
2025-06-04 00:19 [INFO] : Skipping populating dataset_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:75)
2025-06-04 00:19 [INFO] : Skipping populating flattened_graph_metadata_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:90)
2025-06-04 00:19 [INFO] : Skipping populating subgraph_sampling_strategy_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:111)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec.NodeAnchorBasedLinkPredictionModelingTaskSpec (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec.NodeAnchorBasedLinkPredictionModelingTaskSpec". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec.NodeAnchorBasedLinkPredictionModelingTaskSpec'; 'gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec' from '/home/svij/GiGL/python/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py'>, which potentially has object = NodeAnchorBasedLinkPredictionModelingTaskSpec (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access NodeAnchorBasedLinkPredictionModelingTaskSpec in <module 'gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec' from '/home/svij/GiGL/python/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.models.pyg.homogeneous.GraphSAGE (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "gigl.src.common.models.pyg.homogeneous.GraphSAGE". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'gigl.src.common.models.pyg.homogeneous.GraphSAGE'; 'gigl.src.common.models.pyg.homogeneous' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.models.pyg.homogeneous (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'gigl.src.common.models.pyg.homogeneous' from '/home/svij/GiGL/python/gigl/src/common/models/pyg/homogeneous.py'>, which potentially has object = GraphSAGE (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access GraphSAGE in <module 'gigl.src.common.models.pyg.homogeneous' from '/home/svij/GiGL/python/gigl/src/common/models/pyg/homogeneous.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Will try importing module: torch.optim.Adam (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "torch.optim.Adam". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'torch.optim.Adam' (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: torch.optim (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'torch.optim' from '/opt/conda/envs/gnn/lib/python3.9/site-packages/torch/optim/__init__.py'>, which potentially has object = Adam (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access Adam in <module 'torch.optim' from '/opt/conda/envs/gnn/lib/python3.9/site-packages/torch/optim/__init__.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Will try importing module: torch.optim.lr_scheduler.ConstantLR (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "torch.optim.lr_scheduler.ConstantLR". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'torch.optim.lr_scheduler.ConstantLR'; 'torch.optim.lr_scheduler' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: torch.optim.lr_scheduler (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'torch.optim.lr_scheduler' from '/opt/conda/envs/gnn/lib/python3.9/site-packages/torch/optim/lr_scheduler.py'>, which potentially has object = ConstantLR (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access ConstantLR in <module 'torch.optim.lr_scheduler' from '/opt/conda/envs/gnn/lib/python3.9/site-packages/torch/optim/lr_scheduler.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.models.layers.task.Retrieval (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Unable to import "gigl.src.common.models.layers.task.Retrieval". (os_utils.py:_import_module:50)
2025-06-04 00:19 [INFO] : No module named 'gigl.src.common.models.layers.task.Retrieval'; 'gigl.src.common.models.layers.task' is not a package (os_utils.py:_import_module:51)
2025-06-04 00:19 [INFO] : Will try importing module: gigl.src.common.models.layers.task (os_utils.py:_import_module:41)
2025-06-04 00:19 [INFO] : Successfully imported module = <module 'gigl.src.common.models.layers.task' from '/home/svij/GiGL/python/gigl/src/common/models/layers/task.py'>, which potentially has object = Retrieval (os_utils.py:_import_module:45)
2025-06-04 00:19 [INFO] : Trying to access Retrieval in <module 'gigl.src.common.models.layers.task' from '/home/svij/GiGL/python/gigl/src/common/models/layers/task.py'> (os_utils.py:_find_obj_in_module:68)
2025-06-04 00:19 [INFO] : Identified task <class 'gigl.src.common.models.layers.task.Retrieval'> (node_anchor_based_link_prediction_modeling_task_spec.py:__init__:198)
2025-06-04 00:19 [INFO] : Config validation check: if postProcessorClsPath and its args are valid. (template_config_checks.py:check_if_post_processor_cls_valid:439)
2025-06-04 00:19 [INFO] : No post processor class provided - skipping checks for post processor (template_config_checks.py:check_if_post_processor_cls_valid:446)
2025-06-04 00:19 [INFO] : Config validation check: if resource config shared_resource is valid. (resource_config_checks.py:check_if_shared_resource_config_valid:69)
2025-06-04 00:19 [INFO] : [✅ SUCCESS] All checks passed successfully. (config_validator.py:kfp_validation_checks:237)
Config Populator#
Takes in a template GbmlConfig and outputs a frozen GbmlConfig by populating all job-related metadata paths in sharedConfig. These are mostly GCS paths which the following components read from and write to, using them as an intermediary data communication medium. For example, the field sharedConfig.trainedModelMetadata is populated with a GCS URI, which indicates to the Trainer to write the trained model to this path, and to the Inferencer to read the model from this path. See the full Config Populator Guide.
After running the command below, we will have created a frozen config and uploaded it to the perm_assets_bucket provided in the resource config.
The path to that file will be stored in the file at FROZEN_TASK_CONFIG_POINTER_FILE_PATH
%%capture
# We suppress the output of this cell to avoid cluttering the notebook with logs. Remove %%capture if you want to see the output.
!python -m \
gigl.src.config_populator.config_populator \
--job_name="$JOB_NAME" \
--template_uri="$TEMPLATE_TASK_CONFIG_PATH" \
--resource_config_uri="$RESOURCE_CONFIG_PATH" \
--output_file_path_frozen_gbml_config_uri="$FROZEN_TASK_CONFIG_POINTER_FILE_PATH"
# The command above will write the frozen task config path to the file specified by `FROZEN_TASK_CONFIG_POINTER_FILE_PATH`.
# Let's see where it was generated
FROZEN_TASK_CONFIG_PATH: Uri
with open(FROZEN_TASK_CONFIG_POINTER_FILE_PATH.uri, 'r') as file:
    FROZEN_TASK_CONFIG_PATH = UriFactory.create_uri(file.read().strip())
print(f"FROZEN_TASK_CONFIG_PATH: {FROZEN_TASK_CONFIG_PATH}")
FROZEN_TASK_CONFIG_PATH: gs://gigl-cicd-temp/gigl_test_job/config_populator/frozen_gbml_config.yaml
Visualizing the diff between template and frozen config.#
We now have a frozen task config, whose path is specified by FROZEN_TASK_CONFIG_PATH. We visualize the diff between the frozen_task_config generated by the config_populator and the original template_task_config.
All the code below exists just to do that, and has nothing to do with GiGL.
Specifically, make note that:
The component added sharedConfig to the yaml, which contains all of the intermediary and final output paths for each component.
It also added a condensedEdgeTypeMap and a condensedNodeTypeMap, which map all provided edge types and node types to ints to save storage space: EdgeType: Tuple[srcNodeType: str, relation: str, dstNodeType: str] -> int, and NodeType: str -> int (see the small illustration below).
Note: You may also provide your own condensed maps; they will be generated for you if not provided
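To make the mapping concrete, here is a tiny illustrative sketch using plain Python dicts. It mirrors the condensed maps you will see in the frozen config diff below; it is not GiGL's internal representation.
# Illustrative only: mirrors the condensedNodeTypeMap / condensedEdgeTypeMap in the frozen config diff below.
condensed_node_type_map = {0: "user"}
condensed_edge_type_map = {
    0: ("user", "is_friends_with", "user"),  # (srcNodeType, relation, dstNodeType)
}

# Storing a small int per node/edge type instead of repeating full string tuples
# saves considerable space when a graph has billions of edges.
src, relation, dst = condensed_edge_type_map[0]
print(f"Condensed edge type 0 -> ({src}, {relation}, {dst})")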
import yaml
from difflib import unified_diff
from IPython.display import display, HTML
from gigl.src.common.utils.file_loader import FileLoader
def sort_yaml_dict_recursively(obj: dict) -> dict:
    # We sort the yaml recursively as the GiGL proto serialization code does not guarantee order of original keys.
    # This is important for the diff to be stable and not show errors due to key/list order changes.
    if isinstance(obj, dict):
        return {k: sort_yaml_dict_recursively(obj[k]) for k in sorted(obj)}
    elif isinstance(obj, list):
        return [sort_yaml_dict_recursively(item) for item in obj]
    else:
        return obj
def show_colored_unified_diff(f1_lines, f2_lines, f1_name, f2_name):
    diff_lines = list(unified_diff(f1_lines, f2_lines, fromfile=f1_name, tofile=f2_name))
    html_lines = []
    for line in diff_lines:
        if line.startswith('+') and not line.startswith('+++'):
            color = '#228B22'  # green
        elif line.startswith('-') and not line.startswith('---'):
            color = '#B22222'  # red
        elif line.startswith('@'):
            color = '#1E90FF'  # blue
        else:
            color = '#000000'  # black
        html_lines.append(f'<pre style="margin:0; color:{color}; background-color:white;">{line.rstrip()}</pre>')
    display(HTML(''.join(html_lines)))
file_loader = FileLoader()
frozen_task_config_file_contents: str
template_task_config_file_contents: str
with open(file_loader.load_to_temp_file(file_uri_src=FROZEN_TASK_CONFIG_PATH).name, 'r') as f:
    data = yaml.safe_load(f)
    # sort_keys by default
    frozen_task_config_file_contents = yaml.dump(sort_yaml_dict_recursively(data))
with open(file_loader.load_to_temp_file(file_uri_src=TEMPLATE_TASK_CONFIG_PATH).name, 'r') as f:
    data = yaml.safe_load(f)
    template_task_config_file_contents = yaml.dump(sort_yaml_dict_recursively(data))
# Example usage
show_colored_unified_diff(
    template_task_config_file_contents.splitlines(),
    frozen_task_config_file_contents.splitlines(),
    f1_name='template_task_config.yaml',
    f2_name='frozen_task_config.yaml',
)
2025-06-04 00:19 [INFO] : Creating symlink /tmp/tmp38v2re8k -> /home/svij/GiGL/examples/toy_visual_example/template_task_config.yaml (local_fs.py:create_file_symlinks:231)
2025-06-04 00:19 [INFO] : Deleted local file at /tmp/tmp38v2re8k (local_fs.py:remove_file_if_exist:64)
--- template_task_config.yaml
+++ frozen_task_config.yaml
@@ -17,6 +17,13 @@
numNeighborsToSample: 2
numPositiveSamples: 1
graphMetadata:
+ condensedEdgeTypeMap:
+ '0':
+ dstNodeType: user
+ relation: is_friends_with
+ srcNodeType: user
+ condensedNodeTypeMap:
+ '0': user
edgeTypes:
- dstNodeType: user
relation: is_friends_with
@@ -29,6 +36,35 @@
num_layers: '2'
out_dim: '128'
inferencerClsPath: gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec.NodeAnchorBasedLinkPredictionModelingTaskSpec
+sharedConfig:
+ datasetMetadata:
+ nodeAnchorBasedLinkPredictionDataset:
+ testMainDataUri: gs://gigl-cicd-temp/gigl_test_job/split_generator/test/main_samples/samples/
+ testNodeTypeToRandomNegativeDataUri:
+ user: gs://gigl-cicd-temp/gigl_test_job/split_generator/test/random_negatives/user/neighborhoods/
+ trainMainDataUri: gs://gigl-cicd-temp/gigl_test_job/split_generator/train/main_samples/samples/
+ trainNodeTypeToRandomNegativeDataUri:
+ user: gs://gigl-cicd-temp/gigl_test_job/split_generator/train/random_negatives/user/neighborhoods/
+ valMainDataUri: gs://gigl-cicd-temp/gigl_test_job/split_generator/val/main_samples/samples/
+ valNodeTypeToRandomNegativeDataUri:
+ user: gs://gigl-cicd-temp/gigl_test_job/split_generator/val/random_negatives/user/neighborhoods/
+ flattenedGraphMetadata:
+ nodeAnchorBasedLinkPredictionOutput:
+ nodeTypeToRandomNegativeTfrecordUriPrefix:
+ user: gs://gigl-cicd-temp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/random_negative_rooted_neighborhood_samples/user/samples/
+ tfrecordUriPrefix: gs://gigl-cicd-temp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/node_anchor_based_link_prediction_samples/samples/
+ inferenceMetadata:
+ nodeTypeToInferencerOutputInfoMap:
+ user:
+ embeddingsPath: external-snap-ci-github-gigl.gigl_temp_assets.embeddings_user_gigl_test_job
+ postprocessedMetadata:
+ postProcessorLogMetricsUri: gs://gigl-cicd-temp/gigl_test_job/post_processor/post_processor_metrics.json
+ preprocessedMetadataUri: gs://gigl-cicd-temp/gigl_test_job/data_preprocess/preprocessed_metadata.yaml
+ trainedModelMetadata:
+ evalMetricsUri: gs://gigl-cicd-temp/gigl_test_job/trainer/models/trainer_eval_metrics.json
+ scriptedModelUri: gs://gigl-cicd-temp/gigl_test_job/trainer/models/scripted_model.pt
+ tensorboardLogsUri: gs://gigl-cicd-temp/gigl_test_job/trainer/tensorboard_logs/
+ trainedModelUri: gs://gigl-cicd-temp/gigl_test_job/trainer/models/model.pt
taskMetadata:
nodeAnchorBasedLinkPredictionTaskMetadata:
supervisionEdgeTypes:
# We will load the frozen task and resource configs into objects so we can reference them in the next cells
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import GiglResourceConfigWrapper
frozen_task_config = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
    gbml_config_uri=FROZEN_TASK_CONFIG_PATH
)
resource_config: GiglResourceConfigWrapper = get_resource_config(
    resource_config_uri=RESOURCE_CONFIG_PATH
)
2025-06-04 00:19 [INFO] : Skipping populating subgraph_sampling_strategy_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:111)
Compiling Src Docker images#
You will need to build and push Docker images with your custom code so that individual GiGL components can leverage it. For this experiment we will consider the toy_visual_example specs and relevant code to be “custom code”, and we will guide you through building a Docker image containing it.
We will make use of scripts/build_and_push_docker_image.py for this.
Make note that this builds containers/Dockerfile.src and containers/Dockerfile.dataflow.src, which have instructions to COPY the examples folder - which contains all the source code for this example - alongside all the GiGL src code.
# We suppress the output of this cell to avoid cluttering the notebook with logs. Remove %%capture if you want to see the output.
from scripts.build_and_push_docker_image import build_and_push_cpu_image, build_and_push_cuda_image, build_and_push_dataflow_image
import datetime
project = resource_config.project
curr_datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Change to wherever you want to store the docker images.
DOCKER_ARTIFACT_REGISTRY = f"us-central1-docker.pkg.dev/{project}/gigl-base-images"
DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG = f"{DOCKER_ARTIFACT_REGISTRY}/gigl_dataflow_runtime:{curr_datetime}"
DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG = f"{DOCKER_ARTIFACT_REGISTRY}/gigl_cuda:{curr_datetime}"
DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG = f"{DOCKER_ARTIFACT_REGISTRY}/gigl_cpu:{curr_datetime}"
os.environ["DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG"] = DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG
os.environ["DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG"] = DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG
os.environ["DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG"] = DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG
print(DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG)
build_and_push_dataflow_image(
    image_name=DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG,
)
build_and_push_cuda_image(
    image_name=DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG,
)
build_and_push_cpu_image(
    image_name=DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG,
)
print(f"""We built and pushed the following docker images:
- {DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG}
- {DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG}
- {DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG}
""")
Data Preprocessor#
Once we have a frozen_task_config, the first step is to preprocess the data.
The Data Preprocessor component uses Tensorflow Transform to achieve data transformation in a distributed fashion.
Any custom preprocessing is to be defined in the preprocessor class, specified in the task config by datasetConfig.dataPreprocessorConfig.dataPreprocessorConfigClsPath. This class must inherit from gigl.src.data_preprocessor.lib.data_preprocessor_config.DataPreprocessorConfig
In your preprocessor spec you must implement the following 3 functions as defined by the base class DataPreprocessorConfig:
prepare_for_pipeline: Prepares datasets for ingestion and transformation
get_nodes_preprocessing_spec: Defines transformation imperatives for different node types
get_edges_preprocessing_spec: Defines transformation imperatives for different edge types
Please take a look at toy_data_preprocessor_config.py to see how these are defined. You will note that in this case we are not doing anything special, i.e. no feature engineering: we just read from BQ and pass the features through. If we wanted, we could define our own preprocessing function and use it in place of the build_passthrough_transform_preprocessing_fn() defined in the code.
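A bare-bones skeleton of such a config class might look like the following. The method signatures are deliberately simplified here; consult toy_data_preprocessor_config.py and the DataPreprocessorConfig base class for the exact interfaces.
from gigl.src.data_preprocessor.lib.data_preprocessor_config import DataPreprocessorConfig

class MyDataPreprocessorConfig(DataPreprocessorConfig):
    """Skeleton only: the real signatures and return types are defined by the base class."""

    def prepare_for_pipeline(self, *args, **kwargs):
        # One-time setup before the pipeline runs, e.g. staging BQ tables for ingestion.
        ...

    def get_nodes_preprocessing_spec(self, *args, **kwargs):
        # Return the transformation spec for each node type
        # (e.g. a passthrough preprocessing_fn if no feature engineering is needed).
        ...

    def get_edges_preprocessing_spec(self, *args, **kwargs):
        # Return the transformation spec for each edge type.
        ...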
Input params and Output Paths for Data Preprocessor#
Let's take a quick look at what these are from our frozen config.
print("Frozen Config Datapreprocessor Information:")
print("- Data Preprocessor Config: Specifies what class to use for datapreprocessing and any arguments that might be passed in at runtime to that class")
print(textwrap.indent(str(frozen_task_config.dataset_config.data_preprocessor_config), '\t'))
print("- Preprocessed Metadata Uri: Specifies path to the preprocessed metadata file that will be generated by this component and used by subsequent components to understand and find the data that was preprocessed")
print(textwrap.indent(str(frozen_task_config.shared_config.preprocessed_metadata_uri), '\t'))
Frozen Config Datapreprocessor Information:
- Data Preprocessor Config: Specifies what class to use for datapreprocessing and any arguments that might be passed in at runtime to that class
data_preprocessor_config_cls_path: "examples.toy_visual_example.toy_data_preprocessor_config.ToyDataPreprocessorConfig"
data_preprocessor_args {
key: "bq_edges_table_name"
value: "external-snap-ci-github-gigl.public_gigl.toy_graph_homogeneous_node_anchor_lp_user-friend-user_edges_main_2025-06-02--04-42-21-UTC"
}
data_preprocessor_args {
key: "bq_nodes_table_name"
value: "external-snap-ci-github-gigl.public_gigl.toy_graph_homogeneous_node_anchor_lp_user_nodes_2025-06-02--04-42-21-UTC"
}
- Preprocessed Metadata Uri: Specifies path to the preprocessed metadata file that will be generated by this component and used by subsequent components to understand and find the data that was preprocessed
gs://gigl-cicd-temp/gigl_test_job/data_preprocess/preprocessed_metadata.yaml
Running Data Preprocessor and visualizing the Preprocessed Metadata#
%%capture
!python -m gigl.src.data_preprocessor.data_preprocessor \
--job_name=$JOB_NAME \
--task_config_uri=$FROZEN_TASK_CONFIG_PATH \
--resource_config_uri=$RESOURCE_CONFIG_PATH \
--custom_worker_image_uri=$DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG
Upon completion of the job, we will see the preprocessed metadata populated#
preprocessed_metadata_pb = frozen_task_config.preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb
print(preprocessed_metadata_pb)
condensed_node_type_to_preprocessed_metadata {
key: 0
value {
node_id_key: "node_id"
feature_keys: "f0"
feature_keys: "f1"
tfrecord_uri_prefix: "gs://gigl-cicd-temp/gigl_test_job/data_preprocess/staging/transformed_node_features_dir/user/features/"
schema_uri: "gs://gigl-cicd-temp/gigl_test_job/data_preprocess/node/user/tft_transform_dir/transformed_metadata/schema.pbtxt"
enumerated_node_ids_bq_table: "external-snap-ci-github-gigl.gigl_temp_assets.enumerated_node_user_ids_gigl_test_job"
enumerated_node_data_bq_table: "external-snap-ci-github-gigl.gigl_temp_assets.enumerated_node_user_node_features_gigl_test_job"
feature_dim: 2
transform_fn_assets_uri: "gs://gigl-cicd-temp/gigl_test_job/data_preprocess/node/user/tft_transform_dir/transform_fn/assets"
}
}
condensed_edge_type_to_preprocessed_metadata {
key: 0
value {
src_node_id_key: "src"
dst_node_id_key: "dst"
main_edge_info {
tfrecord_uri_prefix: "gs://gigl-cicd-temp/gigl_test_job/data_preprocess/staging/transformed_edge_features_dir/user-is_friends_with-user/main/features/"
schema_uri: "gs://gigl-cicd-temp/gigl_test_job/data_preprocess/edge/user-is_friends_with-user/main/tft_transform_dir/transformed_metadata/schema.pbtxt"
enumerated_edge_data_bq_table: "external-snap-ci-github-gigl.gigl_temp_assets.enumerated_edge_user-is_friends_with-user_main_edge_features_gigl_test_job"
transform_fn_assets_uri: "gs://gigl-cicd-temp/gigl_test_job/data_preprocess/edge/user-is_friends_with-user/main/tft_transform_dir/transform_fn/assets"
}
}
}
You do not have to worry about these details in code, as it is all handled by the data preprocessor component and subsequent data loaders. But for the sake of understanding, we will investigate condensed_node_type = 0 and condensed_edge_type = 0. If you remember from the frozen config, the mappings were as follows:
print("Condensed Node Type Mapping:")
print(textwrap.indent(str(frozen_task_config.graph_metadata.condensed_node_type_map), '\t'))
print("Condensed Edge Type Mapping:")
print(textwrap.indent(str(frozen_task_config.graph_metadata.condensed_edge_type_map), '\t'))
preprocessed_nodes = preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[0].tfrecord_uri_prefix
preprocessed_edges = preprocessed_metadata_pb.condensed_edge_type_to_preprocessed_metadata[0].main_edge_info.tfrecord_uri_prefix
print(f"Preprocessed Nodes are stored in: {preprocessed_nodes}")
print(f"Preprocessed Edges are stored in: {preprocessed_edges}")
Condensed Node Type Mapping:
{0: 'user'}
Condensed Edge Type Mapping:
{0: relation: "is_friends_with"
src_node_type: "user"
dst_node_type: "user"
}
Preprocessed Nodes are stored in: gs://gigl-cicd-temp/gigl_test_job/data_preprocess/staging/transformed_node_features_dir/user/features/
Preprocessed Edges are stored in: gs://gigl-cicd-temp/gigl_test_job/data_preprocess/staging/transformed_edge_features_dir/user-is_friends_with-user/main/features/
There is not a lot of data, so we will likely have generated just one file each for the preprocessed nodes and edges.
!gsutil ls $preprocessed_nodes && gsutil ls $preprocessed_edges
gs://gigl-cicd-temp/gigl_test_job/data_preprocess/staging/transformed_node_features_dir/user/features/-00000-of-00001.tfrecord
gs://gigl-cicd-temp/gigl_test_job/data_preprocess/staging/transformed_edge_features_dir/user-is_friends_with-user/main/features/-00000-of-00001.tfrecord
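If you want to peek inside one of these shards, a few lines of TensorFlow will do. The sketch below assumes the transformed records are serialized tf.train.Example protos (the usual output of Tensorflow Transform):
import tensorflow as tf

# Read the first record from the preprocessed node-features shard and decode it.
ds = tf.data.TFRecordDataset(tf.io.gfile.glob(preprocessed_nodes + "*.tfrecord"))
for raw in ds.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw.numpy())
    print(example)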
Subgraph Sampler#
The Subgraph Sampler receives node and edge data from the Data Preprocessor and generates k-hop localized subgraphs for each node in the graph. The purpose is to store the neighborhood of each node independently, thereby reducing the memory footprint for downstream components: they need not load the entire graph into memory, only batches of these node neighborhoods. To run the Subgraph Sampler we use the following command:
%%capture
!python -m gigl.src.subgraph_sampler.subgraph_sampler \
--job_name=$JOB_NAME \
--task_config_uri=$FROZEN_TASK_CONFIG_PATH \
--resource_config_uri=$RESOURCE_CONFIG_PATH
Upon completion, there will be two different directories of subgraph samples. One holds the main node anchor based link prediction samples, and the other holds random negative rooted neighborhood samples; both are stored in the locations specified in the frozen_config:
flattened_graph_metadata = frozen_task_config.shared_config.flattened_graph_metadata
print(flattened_graph_metadata)
node_anchor_based_link_prediction_output {
tfrecord_uri_prefix: "gs://gigl-cicd-temp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/node_anchor_based_link_prediction_samples/samples/"
node_type_to_random_negative_tfrecord_uri_prefix {
key: "user"
value: "gs://gigl-cicd-temp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/random_negative_rooted_neighborhood_samples/user/samples/"
}
}
The main node anchor based link prediction samples include each root node's k-hop neighborhood, the k-hop neighborhoods of its positive nodes, and the positive edges; these samples will be used for training. The random_negative_rooted_neighborhood_samples (which contain each root node's k-hop neighborhood) serve a dual purpose: they are used by the Inferencer, and as random negative samples for training.
Random negatives are used so that the model can learn from non-existent (negative) edges; trained on positive samples alone, it could overfit and fail to generalize well to unseen data. A negative edge is simply an edge chosen at random; at large scale, a randomly chosen edge is most probably a true negative.
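A quick back-of-the-envelope check of that last claim: in a sparse graph, the chance that a uniformly random node pair happens to be a real edge is tiny. The numbers below are made up for illustration.
# Probability that a uniformly random ordered pair of nodes is a true edge (illustrative numbers).
num_nodes = 10_000_000       # e.g. a large social graph
num_edges = 1_000_000_000    # ~100 friends per user
possible_pairs = num_nodes * (num_nodes - 1)
p_true_edge = num_edges / possible_pairs
print(f"P(random pair is a real edge) ~ {p_true_edge:.2e}")  # ~1e-05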
Below we visualize a root node's neighborhood, the neighborhood of its pos_edge's destination node, and the resulting sample for the root node.
Since the Subgraph Sampler samples randomly here, you will get different subgraphs every time you run this. For the purposes of the example we also provide some screenshots of what these graphs might look like:
When training you may see a sample for node 9 as follows. Specifically, note that edge 9 --> 7 is classified as a positive edge.
Here, the Subgraph Sampler has sampled a 2-hop subgraph with incoming edges to both nodes 9 and 7.

Secondly, we may choose to randomly sample a rooted node neighborhood to act as a “negative sample”; in this case we sample node 1, and the 9 --> 1 edge can serve as a “negative sample”.

# We will sample node with this id and visualize it. You will see positive edge marked in red and the root node with a black border.
SAMPLE_NODE_ID = 9
# We will sample random negative nodes with these ids and visualize them. You will see the root node with a black border.
SAMPLE_RANDOM_NEGATIVE_NODE_IDS = [1, 3]
import tensorflow as tf
from typing import Union, Literal
from snapchat.research.gbml import training_samples_schema_pb2
from examples.toy_visual_example.visualize import GraphVisualizer
def find_node_pb(
    tfrecord_uri_prefix: str,
    node_id: int,
    pb_type: Union[
        Literal['NodeAnchorBasedLinkPredictionSample'],
        Literal['RootedNodeNeighborhood'],
    ],
):
    # Scan every TFRecord shard under the prefix and return the sample rooted at node_id.
    uri = tfrecord_uri_prefix + "*.tfrecord"
    ds = tf.data.TFRecordDataset(tf.io.gfile.glob(uri)).as_numpy_iterator()
    for bytestr in ds:
        if pb_type == 'RootedNodeNeighborhood':
            pb = training_samples_schema_pb2.RootedNodeNeighborhood()
        elif pb_type == 'NodeAnchorBasedLinkPredictionSample':
            pb = training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample()
        else:
            raise ValueError(f"Unsupported pb_type: {pb_type}")
        pb.ParseFromString(bytestr)
        if pb.root_node.node_id == node_id:
            return pb
    return None  # No sample with this root node_id was found
print(f"The original global graph:")
GraphVisualizer.visualize_graph(original_graph_heterodata)
sample = find_node_pb(
    tfrecord_uri_prefix=flattened_graph_metadata.node_anchor_based_link_prediction_output.tfrecord_uri_prefix,
    node_id=SAMPLE_NODE_ID,
    pb_type='NodeAnchorBasedLinkPredictionSample',
)
print(f"Node anchor prediction sample for node {SAMPLE_NODE_ID}:")
GraphVisualizer.plot_graph(sample)
for random_negative_node_id in SAMPLE_RANDOM_NEGATIVE_NODE_IDS:
    random_negative_sample = find_node_pb(
        tfrecord_uri_prefix=flattened_graph_metadata.node_anchor_based_link_prediction_output.node_type_to_random_negative_tfrecord_uri_prefix["user"],
        node_id=random_negative_node_id,
        pb_type='RootedNodeNeighborhood',
    )
    print(f"Random negative sample for node {random_negative_node_id}:")
    GraphVisualizer.plot_graph(random_negative_sample)
The original global graph:

Node anchor prediction sample for node 9:

Random negative sample for node 1:

Random negative sample for node 3:

Split Generator#
The Split Generator reads localized subgraph samples produced by the Subgraph Sampler, and executes the user-specified split strategy logic to split the data into training, validation and test sets. Several standard configurations of SplitStrategy and corresponding Assigner classes are already implemented at a GiGL platform-level: transductive node classification, inductive node classification, and transductive link prediction split routines. For more information on split strategies in Graph Machine Learning, check out these resources:
http://web.stanford.edu/class/cs224w/slides/07-theory.pdf
https://zqfang.github.io/2021-08-12-graph-linkpredict/ (relevant for explaining transductive vs inductive)
In this example, we are using the transductive strategy as specified in our frozen_config:
print(frozen_task_config.dataset_config.split_generator_config)
split_strategy_cls_path: "splitgenerator.lib.split_strategies.TransductiveNodeAnchorBasedLinkPredictionSplitStrategy"
assigner_cls_path: "splitgenerator.lib.assigners.TransductiveEdgeToLinkSplitHashingAssigner"
assigner_args {
key: "seed"
value: "42"
}
assigner_args {
key: "test_split"
value: "0.2"
}
assigner_args {
key: "train_split"
value: "0.7"
}
assigner_args {
key: "val_split"
value: "0.1"
}
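To build intuition for the hashing assigner configured above, here is a toy illustration in plain Python (not GiGL's actual Assigner implementation) of deterministic hash-based split assignment with the same seed and 0.7/0.1/0.2 ratios:
import hashlib

def assign_split(edge, train=0.7, val=0.1, seed="42"):
    # Deterministic: the same (seed, edge) pair always hashes to the same bucket.
    h = int(hashlib.md5(f"{seed}:{edge}".encode()).hexdigest(), 16) / 16**32
    if h < train:
        return "train"
    return "val" if h < train + val else "test"

edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (1, 3), (2, 4)]
print({edge: assign_split(edge) for edge in edges})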
In the transductive setting, at training time the model uses training message edges to predict training supervision edges. At validation time, the training message edges and training supervision edges are used to predict the validation edges; at test time, all three are used to predict the test edges. Below is the command to run the split generator:
%%capture
!python -m gigl.src.split_generator.split_generator \
--job_name=$JOB_NAME \
--task_config_uri=$FROZEN_TASK_CONFIG_PATH \
--resource_config_uri=$RESOURCE_CONFIG_PATH
Upon completion, there will be 3 folders for train, test, and val. Each of them contains the protos for the positive and negative samples. The paths for these folders are specified in the following location in the frozen_config:
dataset_metadata = frozen_task_config.shared_config.dataset_metadata
print(dataset_metadata)
node_anchor_based_link_prediction_dataset {
train_main_data_uri: "gs://gigl-cicd-temp/gigl_test_job/split_generator/train/main_samples/samples/"
test_main_data_uri: "gs://gigl-cicd-temp/gigl_test_job/split_generator/test/main_samples/samples/"
val_main_data_uri: "gs://gigl-cicd-temp/gigl_test_job/split_generator/val/main_samples/samples/"
train_node_type_to_random_negative_data_uri {
key: "user"
value: "gs://gigl-cicd-temp/gigl_test_job/split_generator/train/random_negatives/user/neighborhoods/"
}
val_node_type_to_random_negative_data_uri {
key: "user"
value: "gs://gigl-cicd-temp/gigl_test_job/split_generator/val/random_negatives/user/neighborhoods/"
}
test_node_type_to_random_negative_data_uri {
key: "user"
value: "gs://gigl-cicd-temp/gigl_test_job/split_generator/test/random_negatives/user/neighborhoods/"
}
}
We can visualize the train, test, and val samples for all nodes. Note that for this specific task config setup, not all val and test samples will have positive labels. This is because edges are randomly assigned into “train”, “val”, or “test” buckets independent of whether or not they are supervision edges. Although this setting is okay at scale - with large data and a large batch size, each batch will have some supervision edges - it can be the case that certain batches don't have any supervision edges. Thus, your val/test loops and early stopping logic may need to be carefully designed; one defensive pattern is sketched below.
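For example, a val/test loop can simply skip samples that carry no supervision edges. This sketch assumes the NodeAnchorBasedLinkPredictionSample proto exposes its supervision edges via a pos_edges repeated field; check training_samples_schema_pb2 for the actual field names.
def iter_samples_with_supervision(samples):
    # Yield only samples that actually carry supervision (positive) edges,
    # so loss/metric computations never see an empty label set.
    # Assumes `pos_edges` is the repeated field holding supervision edges.
    for sample in samples:
        if sample is None or len(sample.pos_edges) == 0:
            continue
        yield sample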
for node_id in range(original_graph_heterodata.num_nodes):
    print(f"Node anchor prediction sample for node {node_id}:")
    sample_train = find_node_pb(
        tfrecord_uri_prefix=dataset_metadata.node_anchor_based_link_prediction_dataset.train_main_data_uri,
        node_id=node_id,
        pb_type='NodeAnchorBasedLinkPredictionSample',
    )
    sample_val = find_node_pb(
        tfrecord_uri_prefix=dataset_metadata.node_anchor_based_link_prediction_dataset.val_main_data_uri,
        node_id=node_id,
        pb_type='NodeAnchorBasedLinkPredictionSample',
    )
    sample_test = find_node_pb(
        tfrecord_uri_prefix=dataset_metadata.node_anchor_based_link_prediction_dataset.test_main_data_uri,
        node_id=node_id,
        pb_type='NodeAnchorBasedLinkPredictionSample',
    )
    if sample_train:
        print(f"Train sample for node {node_id}: ")
        GraphVisualizer.plot_graph(sample_train)
    else:
        print(f"No train sample found for node {node_id}.")
    if sample_val:
        print(f"Validation sample for node {node_id}:")
        GraphVisualizer.plot_graph(sample_val)
    else:
        print(f"No validation sample found for node {node_id}.")
    if sample_test:
        print(f"Test sample for node {node_id}:")
        GraphVisualizer.plot_graph(sample_test)
    else:
        print(f"No test sample found for node {node_id}.")
Node anchor prediction sample for node 0:
Train sample for node 0:

Validation sample for node 0:

Test sample for node 0:

Node anchor prediction sample for node 1:
No train sample found for node 1.
Validation sample for node 1:

Test sample for node 1:

Node anchor prediction sample for node 2:
Train sample for node 2:

Validation sample for node 2:

Test sample for node 2:

Node anchor prediction sample for node 3:
No train sample found for node 3.
Validation sample for node 3:

Test sample for node 3:

Node anchor prediction sample for node 4:
Train sample for node 4:

Validation sample for node 4:

Test sample for node 4:

Node anchor prediction sample for node 5:
Train sample for node 5:

Validation sample for node 5:

Test sample for node 5:

Node anchor prediction sample for node 6:
Train sample for node 6:

Validation sample for node 6:

Test sample for node 6:

Node anchor prediction sample for node 7:
Train sample for node 7:

Validation sample for node 7:

Test sample for node 7:

Node anchor prediction sample for node 8:
Train sample for node 8:

Validation sample for node 8:

Test sample for node 8:

Node anchor prediction sample for node 9:
Train sample for node 9:

Validation sample for node 9:

Test sample for node 9:

Node anchor prediction sample for node 10:
Train sample for node 10:

Validation sample for node 10:

Test sample for node 10:

Node anchor prediction sample for node 11:
Train sample for node 11:

Validation sample for node 11:

Test sample for node 11:

Node anchor prediction sample for node 12:
Train sample for node 12:

Validation sample for node 12:

Test sample for node 12:

Node anchor prediction sample for node 13:
Train sample for node 13:

Validation sample for node 13:

Test sample for node 13:

At this point, we have our graph data samples ready to be processed by the trainer and inferencer components. These components will extract representations/embeddings by learning contextual information for the specified task.