Toy Example - Tabularized GiGL#
The latest version of this notebook can be found on GitHub.
This notebook provides a walkthrough of the preprocessing, subgraph sampling, and split generation components on a small toy graph, using GiGL’s Tabularized setting for training/inference. It will help you understand how each of these components prepares tabularized subgraphs.
Overview Of Components#
This notebook demonstrates a simple, human-digestible graph being passed through all of GiGL's pipeline components in preparation for training, to help you understand how each component works.
The pipeline consists of the following components:
Config Populator: Takes a template config and creates a frozen workflow config that dictates all inputs/outputs and business parameters that are read and used by each subsequent component.
Input:
template_config.yaml
Output:
frozen_gbml_config.yaml
Data Preprocessor: Transforms node and edge feature assets as needed (a precursor step in most ML tasks), according to the user-provided data preprocessor config class.
Input:
frozen_gbml_config.yaml
, which includes the user-defined preprocessor class for custom logic; custom arguments can be passed under dataPreprocessorArgs.
Output: PreprocessedMetadata proto, which includes the inferred GraphMetadata and the preprocessed graph data TFRecords produced by applying the user-defined preprocessing function.
Subgraph Sampler: Samples k-hop subgraphs for each node according to user-provided arguments.
Input:
frozen_gbml_config.yaml
, resource_config.yaml
Output: Subgraph Samples (tfrecord format based on predefined schema in protos) are stored in the URI defined in the flattenedGraphMetadata field.
Split Generator: Splits subgraph sampler outputs into train/test/val sets according to the user-provided split strategy class.
Input:
frozen_gbml_config.yaml
, which includes an instance of SplitStrategy and an instance of Assigner.
Output: TFRecord samples.
Trainer: The trainer component reads the output of the split generator and trains a model on the training set, stops based on the validation set, and evaluates on the test set.
Input:
frozen_gbml_config.yaml
Output: state_dict stored in trainedModelUri.
Inferencer: Runs inference using a trained model on samples generated by the Subgraph Sampler.
Input:
frozen_gbml_config.yaml
Output: Embeddings and/or prediction assets.
%load_ext autoreload
%autoreload 2
from gigl.common.utils.jupyter_magics import change_working_dir_to_gigl_root
change_working_dir_to_gigl_root()
Input Graph#
We use the input graph defined in examples/toy_visual_example/graph_config.yaml. You are welcome to change this file to a custom graph of your own choosing.
Visualizing the Input Graph#
from torch_geometric.data import HeteroData
from gigl.common.utils.jupyter_magics import GraphVisualizer
from gigl.src.mocking.toy_asset_mocker import load_toy_graph
original_graph_heterodata: HeteroData = load_toy_graph(graph_config_path="examples/toy_visual_example/graph_config.yaml")
# Visualize the graph
GraphVisualizer.visualize_graph(original_graph_heterodata)
2025-06-26 21:07:55.399569: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-26 21:07:55.399629: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-26 21:07:55.401809: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-26 21:07:55.412870: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

Setting up Configs#
The first thing we need to do is create the resource and task configs.
Task Config: Specifies task-related configurations, guiding the behavior of components according to the needs of your machine learning task. See Task Config Guide. For this task, we have already provided a task config: task_config.yaml.
Resource Config: Details the resource allocation and environmental settings across all GiGL components. This encompasses shared resources for all components, as well as component-specific settings. See Resource Config Guide. For this task, we provide resource_config.yaml. The default values provided in
shared_resource_config.common_compute_config
will need to be changed.
Instructions to configure the resource config: If you have not already, please follow the Quick Start Guide to set up your cloud environment and create a default test resource config. You can then copy the relevant
shared_resource_config.common_compute_config
to resource_config.yaml.
import os
import pathlib
import textwrap
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from gigl.common import Uri, UriFactory
notebook_dir = pathlib.Path("./examples/toy_visual_example").as_posix() # We should be in root dir because of cell # 1
# Add the root gigl dir to the Python path so `example` folder can be imported as a module.
# You are welcome to customize these to point to your own configuration files.
JOB_NAME = "gigl_test_job"
TEMPLATE_TASK_CONFIG_PATH: Uri = UriFactory.create_uri(f"{notebook_dir}/template_task_config.yaml")
FROZEN_TASK_CONFIG_POINTER_FILE_PATH: Uri = UriFactory.create_uri(f"/tmp/GiGL/{JOB_NAME}/frozen_task_config.yaml")
pathlib.Path(FROZEN_TASK_CONFIG_POINTER_FILE_PATH.uri).parent.mkdir(parents=True, exist_ok=True)
# Ensure you change the resource config path to point to your own resource configuration
# i.e. what was exported to $GIGL_TEST_DEFAULT_RESOURCE_CONFIG as part of the quick start guide.
RESOURCE_CONFIG_PATH: Uri = UriFactory.create_uri(f"{notebook_dir}/resource_config.yaml")
# Export string format of the uris so we can reference them in cells that execute bash commands below.
os.environ["TEMPLATE_TASK_CONFIG_PATH"] = TEMPLATE_TASK_CONFIG_PATH.uri
os.environ["FROZEN_TASK_CONFIG_POINTER_FILE_PATH"] = FROZEN_TASK_CONFIG_POINTER_FILE_PATH.uri
os.environ["RESOURCE_CONFIG_PATH"] = RESOURCE_CONFIG_PATH.uri
Note on the Use of Mocked Assets#
This step is already done for you. We provide instructions below for posterity, in case the mocked data input “graph_config.yaml” is updated.
You can choose to update MOCK_DATA_GCS_BUCKET
and MOCK_DATA_BQ_DATASET_NAME
in python/gigl/src/mocking/lib/constants.py
to upload to resources you own.
We run the following command to upload the relevant mocks to GCS and BQ:
python -m gigl.src.mocking.dataset_asset_mocking_suite \
--select mock_toy_graph_homogeneous_node_anchor_based_link_prediction_dataset \
--resource_config_uri=examples/toy_visual_example/resource_config.yaml
Subsequently, we can update the paths in task_config.yaml.
Validating the Configs#
We provide the ability to validate your resource and task configs. Although the validation is not exhaustive, it does help assert that the more common issues are not present before expensive compute is scheduled.
from gigl.src.validation_check.config_validator import kfp_validation_checks
validator = kfp_validation_checks(
job_name=JOB_NAME,
task_config_uri=TEMPLATE_TASK_CONFIG_PATH,
resource_config_uri=RESOURCE_CONFIG_PATH,
start_at="config_populator",
)
Config Populator#
Takes in a template GbmlConfig and outputs a frozen GbmlConfig by populating all job-related metadata paths in sharedConfig. These are mostly GCS paths that the subsequent components read from and write to, serving as an intermediary data communication medium. For example, the field sharedConfig.trainedModelMetadata is populated with a GCS URI, which tells the Trainer to write the trained model to that path, and the Inferencer to read the model from it. See the full Config Populator Guide.
After running the command below, we will have created a frozen config and uploaded it to the perm_assets_bucket provided in the resource config. The path to that file will be stored in the file at FROZEN_TASK_CONFIG_POINTER_FILE_PATH.
%%capture
# We suppress the output of this cell to avoid cluttering the notebook with logs. Remove %%capture if you want to see the output.
!python -m \
gigl.src.config_populator.config_populator \
--job_name="$JOB_NAME" \
--template_uri="$TEMPLATE_TASK_CONFIG_PATH" \
--resource_config_uri="$RESOURCE_CONFIG_PATH" \
--output_file_path_frozen_gbml_config_uri="$FROZEN_TASK_CONFIG_POINTER_FILE_PATH"
# The command above will write the frozen task config path to the file specified by `FROZEN_TASK_CONFIG_POINTER_FILE_PATH`.
# Let's see where it was generated
FROZEN_TASK_CONFIG_PATH: Uri
with open(FROZEN_TASK_CONFIG_POINTER_FILE_PATH.uri, 'r') as file:
    FROZEN_TASK_CONFIG_PATH = UriFactory.create_uri(file.read().strip())
print(f"FROZEN_TASK_CONFIG_PATH: {FROZEN_TASK_CONFIG_PATH}")
FROZEN_TASK_CONFIG_PATH: gs://svij-gigl-oss-perm/gigl_test_job/config_populator/frozen_gbml_config.yaml
Visualizing the Diff Between Template and Frozen Config#
We now have a frozen task config, with the path specified by FROZEN_TASK_CONFIG_PATH. We visualize the diff between the frozen_task_config generated by the config_populator and the original template_task_config. All the code below is just to produce that diff and has nothing to do with GiGL itself.
Specifically, note that:
The component added sharedConfig to the YAML, which contains all the intermediary and final output paths for each component.
It also added a condensedEdgeTypeMap and a condensedNodeTypeMap, which map all provided edge types and node types to int to save storage space:
EdgeType: Tuple[srcNodeType: str, relation: str, dstNodeType: str] -> int, and
NodeType: str -> int
Note: You may also provide your own condensed maps; they will be generated for you if not provided.
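As a concrete illustration (matching the toy graph used in this example), the condensed maps are conceptually just integer lookups. The dicts below are illustrative only; the authoritative values live in the frozen config's graphMetadata section, as shown in the diff that follows.
```python
# Illustrative only: what the condensed maps conceptually look like for the toy graph.
# The authoritative values are in the frozen config's graphMetadata (see diff below).
condensed_node_type_map = {0: "user"}
condensed_edge_type_map = {
    0: ("user", "is_friends_with", "user"),  # (srcNodeType, relation, dstNodeType)
}
```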
from gigl.common.utils.jupyter_magics import show_task_config_colored_unified_diff
show_task_config_colored_unified_diff(
f1_uri=FROZEN_TASK_CONFIG_PATH,
f2_uri=TEMPLATE_TASK_CONFIG_PATH,
f1_name='frozen_task_config.yaml',
f2_name='template_task_config.yaml'
)
2025-06-26 21:08 [INFO] : Creating symlink /tmp/tmpr3sdwsg5 -> /home/svij/GiGL/examples/toy_visual_example/template_task_config.yaml (local_fs.py:create_file_symlinks:231)
2025-06-26 21:08 [INFO] : Deleted local file at /tmp/tmpr3sdwsg5 (local_fs.py:remove_file_if_exist:64)
--- template_task_config.yaml
+++ frozen_task_config.yaml
@@ -17,6 +17,13 @@
numNeighborsToSample: 2
numPositiveSamples: 1
graphMetadata:
+ condensedEdgeTypeMap:
+ '0':
+ dstNodeType: user
+ relation: is_friends_with
+ srcNodeType: user
+ condensedNodeTypeMap:
+ '0': user
edgeTypes:
- dstNodeType: user
relation: is_friends_with
@@ -29,6 +36,35 @@
num_layers: '2'
out_dim: '128'
inferencerClsPath: gigl.src.common.modeling_task_specs.node_anchor_based_link_prediction_modeling_task_spec.NodeAnchorBasedLinkPredictionModelingTaskSpec
+sharedConfig:
+ datasetMetadata:
+ nodeAnchorBasedLinkPredictionDataset:
+ testMainDataUri: gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/
+ testNodeTypeToRandomNegativeDataUri:
+ user: gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/random_negatives/user/neighborhoods/
+ trainMainDataUri: gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/
+ trainNodeTypeToRandomNegativeDataUri:
+ user: gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/random_negatives/user/neighborhoods/
+ valMainDataUri: gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/
+ valNodeTypeToRandomNegativeDataUri:
+ user: gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/random_negatives/user/neighborhoods/
+ flattenedGraphMetadata:
+ nodeAnchorBasedLinkPredictionOutput:
+ nodeTypeToRandomNegativeTfrecordUriPrefix:
+ user: gs://svij-gigl-oss-tmp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/random_negative_rooted_neighborhood_samples/user/samples/
+ tfrecordUriPrefix: gs://svij-gigl-oss-tmp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/node_anchor_based_link_prediction_samples/samples/
+ inferenceMetadata:
+ nodeTypeToInferencerOutputInfoMap:
+ user:
+ embeddingsPath: gigl-oss-onboarding-exp-test.svij_gigl_oss_perm2.embeddings_user_gigl_test_job
+ postprocessedMetadata:
+ postProcessorLogMetricsUri: gs://svij-gigl-oss-perm/gigl_test_job/post_processor/post_processor_metrics.json
+ preprocessedMetadataUri: gs://svij-gigl-oss-perm/gigl_test_job/data_preprocess/preprocessed_metadata.yaml
+ trainedModelMetadata:
+ evalMetricsUri: gs://svij-gigl-oss-perm/gigl_test_job/trainer/models/trainer_eval_metrics.json
+ scriptedModelUri: gs://svij-gigl-oss-perm/gigl_test_job/trainer/models/scripted_model.pt
+ tensorboardLogsUri: gs://svij-gigl-oss-perm/gigl_test_job/trainer/tensorboard_logs/
+ trainedModelUri: gs://svij-gigl-oss-perm/gigl_test_job/trainer/models/model.pt
taskMetadata:
nodeAnchorBasedLinkPredictionTaskMetadata:
supervisionEdgeTypes:
# We will load the frozen task config and resource config into objects so we can reference them in the next cells
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import GiglResourceConfigWrapper
frozen_task_config = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
gbml_config_uri=FROZEN_TASK_CONFIG_PATH
)
resource_config: GiglResourceConfigWrapper = get_resource_config(
resource_config_uri=RESOURCE_CONFIG_PATH
)
2025-06-26 21:08 [INFO] : Skipping populating subgraph_sampling_strategy_pb_wrapper as the message is missing from the input config (gbml_config.py:__post_init__:111)
Compiling Source Docker Images#
You will need to build and push Docker images with your custom code so that individual GiGL components can leverage your code. For this experiment, we will consider the toy_visual_example specs and relevant code to be “custom code,” and we will guide you on how to build a Docker image with the code.
We will make use of scripts/build_and_push_docker_image.py for this.
Note that this builds containers/Dockerfile.src and containers/Dockerfile.dataflow.src, which contain instructions to COPY the examples folder (which holds all of the source code for this example) as well as all of the GiGL source code.
%%capture
# We suppress the output of this cell to avoid cluttering the notebook with logs. Remove %%capture if you want to see the output.
from scripts.build_and_push_docker_image import build_and_push_cpu_image, build_and_push_cuda_image, build_and_push_dataflow_image
import datetime
project = resource_config.project
curr_datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Change to wherever you want to store the Docker images.
DOCKER_ARTIFACT_REGISTRY = f"us-central1-docker.pkg.dev/{project}/gigl-base-images"
DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG = f"{DOCKER_ARTIFACT_REGISTRY}/gigl_dataflow_runtime:{curr_datetime}"
DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG = f"{DOCKER_ARTIFACT_REGISTRY}/gigl_cuda:{curr_datetime}"
DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG = f"{DOCKER_ARTIFACT_REGISTRY}/gigl_cpu:{curr_datetime}"
os.environ["DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG"] = DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG
os.environ["DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG"] = DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG
os.environ["DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG"] = DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG
print(DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG)
build_and_push_dataflow_image(
image_name=DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG,
)
build_and_push_cuda_image(
image_name=DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG,
)
build_and_push_cpu_image(
image_name=DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG,
)
print(f"""We built and pushed the following docker images:
- {DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG}
- {DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG}
- {DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG}
""")
Data Preprocessor#
Once we have a frozen_task_config, the first step is to preprocess the data.
The Data Preprocessor component uses Tensorflow Transform to achieve data transformation in a distributed fashion.
Any custom preprocessing is defined in the preprocessor class, specified in the task config by datasetConfig.dataPreprocessorConfig.dataPreprocessorConfigClsPath. This class must inherit from gigl.src.data_preprocessor.lib.data_preprocessor_config.DataPreprocessorConfig.
In your preprocessor spec, you must implement the following 3 functions defined by the base class DataPreprocessorConfig:
prepare_for_pipeline: Prepares datasets for ingestion and transformation.
get_nodes_preprocessing_spec: Defines transformation imperatives for different node types.
get_edges_preprocessing_spec: Defines transformation imperatives for different edge types.
Please take a look at toy_data_preprocessor_config.py to see how these are defined. You will note that in this case, we are not doing anything special (i.e., no feature engineering); we just read from BQ and pass the features through. We could, if we wanted, define our own preprocessing function and use it in place of the build_passthrough_transform_preprocessing_fn() defined in that code.
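For orientation, here is a rough skeleton of what such a class can look like. This is a hedged sketch: the method signatures below are simplified placeholders (*args/**kwargs), not the actual interfaces; consult the DataPreprocessorConfig base class and toy_data_preprocessor_config.py for the real arguments and return types.
```python
# Hedged sketch of a custom preprocessor config. Signatures are illustrative placeholders;
# see toy_data_preprocessor_config.py and the DataPreprocessorConfig base class for the
# actual method arguments and return types.
from gigl.src.data_preprocessor.lib.data_preprocessor_config import DataPreprocessorConfig


class MyDataPreprocessorConfig(DataPreprocessorConfig):
    def prepare_for_pipeline(self, *args, **kwargs):
        # Register/prepare the node and edge data sources (e.g. BQ tables) for ingestion.
        ...

    def get_nodes_preprocessing_spec(self, *args, **kwargs):
        # Return the preprocessing spec(s) describing how each node type's features
        # should be transformed (a simple passthrough in the toy example).
        ...

    def get_edges_preprocessing_spec(self, *args, **kwargs):
        # Return the preprocessing spec(s) for each edge type.
        ...
```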
Input Parameters and Output Paths for Data Preprocessor#
Let’s take a quick look at what these are from our frozen config.
print("Frozen Config DataPreprocessor Information:")
print("- Data Preprocessor Config: Specifies what class to use for datapreprocessing and any arguments that might be passed in at runtime to that class")
print(textwrap.indent(str(frozen_task_config.dataset_config.data_preprocessor_config), '\t'))
print("- Preprocessed Metadata Uri: Specifies path to the preprocessed metadata file that will be generated by this component and used by subsequent components to understand and find the data that was preprocessed")
print(textwrap.indent(str(frozen_task_config.shared_config.preprocessed_metadata_uri), '\t'))
Frozen Config DataPreprocessor Information:
- Data Preprocessor Config: Specifies what class to use for data preprocessing and any arguments that might be passed in at runtime to that class
data_preprocessor_config_cls_path: "examples.toy_visual_example.toy_data_preprocessor_config.ToyDataPreprocessorConfig"
data_preprocessor_args {
key: "bq_edges_table_name"
value: "external-snap-ci-github-gigl.public_gigl.toy_graph_homogeneous_node_anchor_lp_user-friend-user_edges_main_2025-06-02--04-42-21-UTC"
}
data_preprocessor_args {
key: "bq_nodes_table_name"
value: "external-snap-ci-github-gigl.public_gigl.toy_graph_homogeneous_node_anchor_lp_user_nodes_2025-06-02--04-42-21-UTC"
}
- Preprocessed Metadata Uri: Specifies path to the preprocessed metadata file that will be generated by this component and used by subsequent components to understand and find the data that was preprocessed
gs://svij-gigl-oss-perm/gigl_test_job/data_preprocess/preprocessed_metadata.yaml
Running Data Preprocessor and visualizing the Preprocessed Metadata#
%%capture
!python -m gigl.src.data_preprocessor.data_preprocessor \
--job_name=$JOB_NAME \
--task_config_uri=$FROZEN_TASK_CONFIG_PATH \
--resource_config_uri=$RESOURCE_CONFIG_PATH \
--custom_worker_image_uri=$DOCKER_IMAGE_DATAFLOW_RUNTIME_NAME_WITH_TAG
Upon completion of the job, we will see the preprocessed metadata populated#
preprocessed_metadata_pb = frozen_task_config.preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb
print(preprocessed_metadata_pb)
condensed_node_type_to_preprocessed_metadata {
key: 0
value {
node_id_key: "node_id"
feature_keys: "f0"
feature_keys: "f1"
tfrecord_uri_prefix: "gs://svij-gigl-oss-tmp/gigl_test_job/data_preprocess/staging/transformed_node_features_dir/user/features/"
schema_uri: "gs://svij-gigl-oss-perm/gigl_test_job/data_preprocess/node/user/tft_transform_dir/transformed_metadata/schema.pbtxt"
enumerated_node_ids_bq_table: "gigl-oss-onboarding-exp-test.svij_gigl_oss_perm2.enumerated_node_user_ids_gigl_test_job"
enumerated_node_data_bq_table: "gigl-oss-onboarding-exp-test.svij_gigl_oss_perm2.enumerated_node_user_node_features_gigl_test_job"
feature_dim: 2
transform_fn_assets_uri: "gs://svij-gigl-oss-perm/gigl_test_job/data_preprocess/node/user/tft_transform_dir/transform_fn/assets"
}
}
condensed_edge_type_to_preprocessed_metadata {
key: 0
value {
src_node_id_key: "src"
dst_node_id_key: "dst"
main_edge_info {
tfrecord_uri_prefix: "gs://svij-gigl-oss-tmp/gigl_test_job/data_preprocess/staging/transformed_edge_features_dir/user-is_friends_with-user/main/features/"
schema_uri: "gs://svij-gigl-oss-perm/gigl_test_job/data_preprocess/edge/user-is_friends_with-user/main/tft_transform_dir/transformed_metadata/schema.pbtxt"
enumerated_edge_data_bq_table: "gigl-oss-onboarding-exp-test.svij_gigl_oss_perm2.enumerated_edge_user-is_friends_with-user_main_edge_features_gigl_test_job"
transform_fn_assets_uri: "gs://svij-gigl-oss-perm/gigl_test_job/data_preprocess/edge/user-is_friends_with-user/main/tft_transform_dir/transform_fn/assets"
}
}
}
You do not have to worry about these details in code, as it is all handled by the data preprocessor component and the subsequent data loaders. But, for the sake of understanding, we will investigate condensed_node_type = 0 and condensed_edge_type = 0. If you recall from the frozen config, the mappings were as follows:
print("Condensed Node Type Mapping:")
print(textwrap.indent(str(frozen_task_config.graph_metadata.condensed_node_type_map), '\t'))
print("Condensed Edge Type Mapping:")
print(textwrap.indent(str(frozen_task_config.graph_metadata.condensed_edge_type_map), '\t'))
preprocessed_nodes = preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[0].tfrecord_uri_prefix
preprocessed_edges = preprocessed_metadata_pb.condensed_edge_type_to_preprocessed_metadata[0].main_edge_info.tfrecord_uri_prefix
print(f"Preprocessed Nodes are stored in: {preprocessed_nodes}")
print(f"Preprocessed Edges are stored in: {preprocessed_edges}")
Condensed Node Type Mapping:
{0: 'user'}
Condensed Edge Type Mapping:
{0: relation: "is_friends_with"
src_node_type: "user"
dst_node_type: "user"
}
Preprocessed Nodes are stored in: gs://svij-gigl-oss-tmp/gigl_test_job/data_preprocess/staging/transformed_node_features_dir/user/features/
Preprocessed Edges are stored in: gs://svij-gigl-oss-tmp/gigl_test_job/data_preprocess/staging/transformed_edge_features_dir/user-is_friends_with-user/main/features/
There is not a lot of data, so we will likely have generated just one file each for the preprocessed nodes and edges.
!gsutil ls $preprocessed_nodes && gsutil ls $preprocessed_edges
gs://svij-gigl-oss-tmp/gigl_test_job/data_preprocess/staging/transformed_node_features_dir/user/features/-00000-of-00001.tfrecord
gs://svij-gigl-oss-tmp/gigl_test_job/data_preprocess/staging/transformed_edge_features_dir/user-is_friends_with-user/main/features/-00000-of-00001.tfrecord
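If you want to peek inside these files yourself, the sketch below lists the shards and decodes one record as a generic tf.train.Example (this assumes TensorFlow is importable in your environment; the schema_uri in the preprocessed metadata describes the exact feature spec). GiGL's own data loaders handle all of this for you, so this is purely for manual inspection.
```python
import tensorflow as tf

# List the TFRecord shards under the preprocessed-nodes prefix and decode the first
# record as a generic tf.train.Example (manual inspection only).
node_shards = tf.io.gfile.glob(f"{preprocessed_nodes}*.tfrecord")
print(f"Found {len(node_shards)} node shard(s)")

for raw_record in tf.data.TFRecordDataset(node_shards).take(1):
    example = tf.train.Example.FromString(raw_record.numpy())
    print(example)
```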
Subgraph Sampler#
The Subgraph Sampler receives node and edge data from the Data Preprocessor and generates k-hop localized subgraphs for each node in the graph. The purpose is to store the neighborhood of each node independently, thereby reducing the memory footprint for downstream components, as they need not load the entire graph into memory but only batches of these node neighborhoods.
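To build intuition for what a fanout-limited k-hop subgraph is (the task config above sets numNeighborsToSample: 2), here is a tiny, purely conceptual sketch on a hypothetical adjacency list. It is not GiGL's actual sampler, which runs as a distributed job over the preprocessed TFRecords.
```python
import random

# Purely conceptual fanout-limited k-hop sampling on a hypothetical toy adjacency list.
adjacency = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 4], 4: [1, 3]}

def sample_k_hop(root: int, num_hops: int = 2, fanout: int = 2) -> set:
    sampled_edges = set()
    frontier = [root]
    for _ in range(num_hops):
        next_frontier = []
        for node in frontier:
            neighbors = adjacency.get(node, [])
            for nbr in random.sample(neighbors, min(fanout, len(neighbors))):
                sampled_edges.add((nbr, node))  # edge pointing into the expanded node
                next_frontier.append(nbr)
        frontier = next_frontier
    return sampled_edges

print(sample_k_hop(root=0))
```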
To run the subgraph sampler, we use the following command:
%%capture
!python -m gigl.src.subgraph_sampler.subgraph_sampler \
--job_name=$JOB_NAME \
--task_config_uri=$FROZEN_TASK_CONFIG_PATH \
--resource_config_uri=$RESOURCE_CONFIG_PATH
Upon completion, there will be two different directories of subgraph samples. One is the main node anchor-based link prediction samples, and the other is random negative rooted neighborhood samples, which are stored in the locations specified in the frozen_config:
flattened_graph_metadata = frozen_task_config.shared_config.flattened_graph_metadata
print(flattened_graph_metadata)
node_anchor_based_link_prediction_output {
tfrecord_uri_prefix: "gs://svij-gigl-oss-tmp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/node_anchor_based_link_prediction_samples/samples/"
node_type_to_random_negative_tfrecord_uri_prefix {
key: "user"
value: "gs://svij-gigl-oss-tmp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/random_negative_rooted_neighborhood_samples/user/samples/"
}
}
The main k-hop node_anchor_based_link_prediction_samples include root nodes’ neighborhoods, positive nodes’ neighborhoods, and positive edges. These samples will be used for training.
The k-hop random_negative_rooted_neighborhood_samples (which include root nodes’ neighborhoods) serve a dual purpose: they will be used for the inferencer and as random negative samples for training.
The random negatives let the model learn from non-existent (negative) edges; training on positive samples alone could cause it to overfit and fail to generalize to unseen data. The negative edges are simply node pairs chosen at random; at large scale, a randomly chosen pair is almost certainly not a true edge.
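To illustrate how positives and random negatives typically come together during training, here is a generic margin-ranking sketch in PyTorch. This is a hedged illustration of the general technique only; it is not necessarily the loss implemented by the NodeAnchorBasedLinkPredictionModelingTaskSpec referenced in the task config.
```python
import torch
import torch.nn.functional as F

# Generic margin-ranking loss over anchor / positive / random-negative embeddings.
# Illustrative only; the actual loss lives in the modeling task spec.
def margin_link_loss(
    anchor: torch.Tensor,    # [B, D] root-node embeddings
    positive: torch.Tensor,  # [B, D] embeddings of positive neighbors
    negative: torch.Tensor,  # [B, D] embeddings of random negatives
    margin: float = 0.5,
) -> torch.Tensor:
    pos_score = F.cosine_similarity(anchor, positive, dim=-1)
    neg_score = F.cosine_similarity(anchor, negative, dim=-1)
    # Push positive scores above negative scores by at least `margin`.
    return F.relu(margin - pos_score + neg_score).mean()

print(margin_link_loss(torch.randn(4, 128), torch.randn(4, 128), torch.randn(4, 128)))
```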
Below, we visualize the original global graph, the node anchor-based link prediction sample for root node 9 (with its positive edge), and rooted neighborhood samples for the random negative nodes 1 and 3.
Since the subgraph sampler is sampling randomly here, you will get different subgraphs every time you run this.
For the purposes of example, we also provide some screenshots of what these graphs might look like:
When training, you may see a sample for node 9 as follows. Specifically, note that edge 9 --> 7 is classified as a positive edge, and the Subgraph Sampler (SGS) has sampled a 2-hop subgraph with incoming edges to both nodes 9 and 7.

Secondly, we may choose to randomly sample a rooted node neighborhood to act as a “negative sample” (i.e., in this case, we sample node 1 and can treat the 9 --> 1 edge as a “negative sample”).

from snapchat.research.gbml import training_samples_schema_pb2
# We will sample node with this id and visualize it. You will see positive edge marked in red and the root node with a black border.
SAMPLE_NODE_ID = 9
# We will sample random negative nodes with these ids and visualize them. You will see the root node with a black border.
SAMPLE_RANDOM_NEGATIVE_NODE_IDS = [1, 3]
from examples.toy_visual_example.visualize import GraphVisualizer, find_node_pb
print(f"The original global graph:")
GraphVisualizer.visualize_graph(original_graph_heterodata)
sample = find_node_pb(
tfrecord_uri_prefix=flattened_graph_metadata.node_anchor_based_link_prediction_output.tfrecord_uri_prefix,
node_id=SAMPLE_NODE_ID,
pb_type=training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
)
print(f"Node anchor prediction sample for node {SAMPLE_NODE_ID}:")
GraphVisualizer.plot_graph(sample)
for random_negative_node_id in SAMPLE_RANDOM_NEGATIVE_NODE_IDS:
    random_negative_sample = find_node_pb(
        tfrecord_uri_prefix=flattened_graph_metadata.node_anchor_based_link_prediction_output.node_type_to_random_negative_tfrecord_uri_prefix["user"],
        node_id=random_negative_node_id,
        pb_type=training_samples_schema_pb2.RootedNodeNeighborhood
    )
    print(f"Random negative sample for node {random_negative_node_id}:")
    GraphVisualizer.plot_graph(random_negative_sample)
The original global graph:

Searching for node 9 in gs://svij-gigl-oss-tmp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/node_anchor_based_link_prediction_samples/samples/*.tfrecord
Node anchor prediction sample for node 9:

Searching for node 1 in gs://svij-gigl-oss-tmp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/random_negative_rooted_neighborhood_samples/user/samples/*.tfrecord
Random negative sample for node 1:

Searching for node 3 in gs://svij-gigl-oss-tmp/gigl_test_job/subgraph_sampler/node_anchor_based_link_prediction/random_negative_rooted_neighborhood_samples/user/samples/*.tfrecord
Random negative sample for node 3:

Split Generator#
The Split Generator reads localized subgraph samples produced by the Subgraph Sampler and executes the user-specified split strategy logic to split the data into training, validation, and test sets. Several standard configurations of SplitStrategy and corresponding Assigner classes are already implemented at a GiGL platform level: transductive node classification, inductive node classification, and transductive link prediction split routines. For more information on split strategies in Graph Machine Learning, check out these resources:
Graph Link Prediction (relevant for explaining transductive vs inductive).
In this example, we are using the transductive strategy as specified in our frozen_config:
print(frozen_task_config.dataset_config.split_generator_config)
split_strategy_cls_path: "splitgenerator.lib.split_strategies.TransductiveNodeAnchorBasedLinkPredictionSplitStrategy"
assigner_cls_path: "splitgenerator.lib.assigners.TransductiveEdgeToLinkSplitHashingAssigner"
assigner_args {
key: "seed"
value: "42"
}
assigner_args {
key: "test_split"
value: "0.2"
}
assigner_args {
key: "train_split"
value: "0.7"
}
assigner_args {
key: "val_split"
value: "0.1"
}
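Conceptually, a hashing assigner buckets each edge into train/val/test by hashing it together with the seed and comparing the result against the cumulative split fractions above. The sketch below is a purely illustrative stand-in for that idea; it is not the TransductiveEdgeToLinkSplitHashingAssigner implementation.
```python
import hashlib

# Illustrative deterministic hash-based bucketing using the fractions from assigner_args.
# Not GiGL's actual assigner; it only demonstrates the hashing idea.
SPLITS = [("train", 0.7), ("val", 0.1), ("test", 0.2)]
SEED = 42

def assign_bucket(src: int, dst: int, seed: int = SEED) -> str:
    digest = hashlib.sha256(f"{seed}:{src}->{dst}".encode()).hexdigest()
    u = int(digest[:8], 16) / 0xFFFFFFFF  # map the hash to a uniform value in [0, 1]
    cumulative = 0.0
    for name, fraction in SPLITS:
        cumulative += fraction
        if u < cumulative:
            return name
    return SPLITS[-1][0]

print(assign_bucket(9, 7), assign_bucket(9, 1))
```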
In the transductive setting, training uses the training message edges to predict the training supervision edges. At validation time, the training message edges and training supervision edges are used to predict the validation edges; at test time, all three are used to predict the test edges. Below is the command to run the split generator:
%%capture
!python -m gigl.src.split_generator.split_generator \
--job_name=$JOB_NAME \
--task_config_uri=$FROZEN_TASK_CONFIG_PATH \
--resource_config_uri=$RESOURCE_CONFIG_PATH
Upon completion, there will be 3 folders for train, test, and val. Each of them contains the protos for the positive and negative samples. The paths for these folders are specified in the following location in the frozen_config:
dataset_metadata = frozen_task_config.shared_config.dataset_metadata
print(dataset_metadata)
node_anchor_based_link_prediction_dataset {
train_main_data_uri: "gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/"
test_main_data_uri: "gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/"
val_main_data_uri: "gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/"
train_node_type_to_random_negative_data_uri {
key: "user"
value: "gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/random_negatives/user/neighborhoods/"
}
val_node_type_to_random_negative_data_uri {
key: "user"
value: "gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/random_negatives/user/neighborhoods/"
}
test_node_type_to_random_negative_data_uri {
key: "user"
value: "gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/random_negatives/user/neighborhoods/"
}
}
We can visualize the train, test, and val samples for all nodes. Note that, for this specific task config setup, not all val and test samples will have positive labels. This is because edges are randomly assigned to the “train”, “val”, or “test” buckets independent of whether or not they are supervision edges. Although this setting is fine at scale (with large data and a large batch size, every batch will contain some supervision edges), individual batches can end up without any supervision edges. Thus, your val/test loops and early stopping logic may need to be carefully designed; a minimal guard is sketched below.
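One simple safeguard is to skip evaluation samples that carry no positive (supervision) edges. This is a hedged sketch that assumes the NodeAnchorBasedLinkPredictionSample proto exposes a repeated pos_edges field; verify against training_samples_schema_pb2 before relying on it. The next cell then visualizes the train, val, and test samples for every node in the toy graph.
```python
# Hedged sketch: skip val/test samples that have no supervision (positive) edges.
# Assumes the sample proto has a repeated `pos_edges` field; check
# training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample for the real schema.
def has_supervision_edges(sample) -> bool:
    return len(sample.pos_edges) > 0

# e.g. inside a val/test loop:
#   if not has_supervision_edges(sample_val):
#       continue  # nothing to score for this root node
```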
for node_id in range(original_graph_heterodata.num_nodes):
    print(f"Node anchor prediction sample for node {node_id}:")
    sample_train = find_node_pb(
        tfrecord_uri_prefix=dataset_metadata.node_anchor_based_link_prediction_dataset.train_main_data_uri,
        node_id=node_id,
        pb_type=training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
    )
    sample_val = find_node_pb(
        tfrecord_uri_prefix=dataset_metadata.node_anchor_based_link_prediction_dataset.val_main_data_uri,
        node_id=node_id,
        pb_type=training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
    )
    sample_test = find_node_pb(
        tfrecord_uri_prefix=dataset_metadata.node_anchor_based_link_prediction_dataset.test_main_data_uri,
        node_id=node_id,
        pb_type=training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
    )
    if sample_train:
        print(f"Train sample for node {node_id}: ")
        GraphVisualizer.plot_graph(sample_train)
    else:
        print(f"No train sample found for node {node_id}.")
    if sample_val:
        print(f"Validation sample for node {node_id}:")
        GraphVisualizer.plot_graph(sample_val)
    else:
        print(f"No validation sample found for node {node_id}.")
    if sample_test:
        print(f"Test sample for node {node_id}:")
        GraphVisualizer.plot_graph(sample_test)
    else:
        print(f"No test sample found for node {node_id}.")
Node anchor prediction sample for node 0:
Searching for node 0 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 0 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 0 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 0:

Validation sample for node 0:

Test sample for node 0:

Node anchor prediction sample for node 1:
Searching for node 1 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 1 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 1 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 1:

Validation sample for node 1:

Test sample for node 1:

Node anchor prediction sample for node 2:
Searching for node 2 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 2 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 2 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 2:

Validation sample for node 2:

Test sample for node 2:

Node anchor prediction sample for node 3:
Searching for node 3 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 3 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 3 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 3:

Validation sample for node 3:

Test sample for node 3:

Node anchor prediction sample for node 4:
Searching for node 4 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 4 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 4 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 4:

Validation sample for node 4:

Test sample for node 4:

Node anchor prediction sample for node 5:
Searching for node 5 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 5 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 5 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 5:

Validation sample for node 5:

Test sample for node 5:

Node anchor prediction sample for node 6:
Searching for node 6 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 6 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 6 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 6:

Validation sample for node 6:

Test sample for node 6:

Node anchor prediction sample for node 7:
Searching for node 7 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 7 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 7 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 7:

Validation sample for node 7:

Test sample for node 7:

Node anchor prediction sample for node 8:
Searching for node 8 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 8 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 8 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 8:

Validation sample for node 8:

Test sample for node 8:

Node anchor prediction sample for node 9:
Searching for node 9 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 9 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 9 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 9:

Validation sample for node 9:

Test sample for node 9:

Node anchor prediction sample for node 10:
Searching for node 10 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 10 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 10 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 10:

Validation sample for node 10:

Test sample for node 10:

Node anchor prediction sample for node 11:
Searching for node 11 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 11 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 11 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 11:

Validation sample for node 11:

Test sample for node 11:

Node anchor prediction sample for node 12:
Searching for node 12 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 12 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 12 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 12:

Validation sample for node 12:

Test sample for node 12:

Node anchor prediction sample for node 13:
Searching for node 13 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/train/main_samples/samples/*.tfrecord
Searching for node 13 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/val/main_samples/samples/*.tfrecord
Searching for node 13 in gs://svij-gigl-oss-tmp/gigl_test_job/split_generator/test/main_samples/samples/*.tfrecord
Train sample for node 13:

Validation sample for node 13:

Test sample for node 13:

At this point, we have our graph data samples ready to be processed by the trainer and inferencer components. These components will extract representations/embeddings by learning contextual information for the specified task.