Task Config Guide
The task config specifies task-related configurations - guiding the behavior of components according to the needs of your machine learning task.
Whenever we say “task config”, we are talking about an instance of
`snapchat.research.gbml.gbml_config_pb2.GbmlConfig`.
This is a protobuf class whose definition can be found in `gbml_config.proto`:
syntax = "proto3";
package snapchat.research.gbml;
import "snapchat/research/gbml/graph_schema.proto";
import "snapchat/research/gbml/flattened_graph_metadata.proto";
import "snapchat/research/gbml/dataset_metadata.proto";
import "snapchat/research/gbml/trained_model_metadata.proto";
import "snapchat/research/gbml/inference_metadata.proto";
import "snapchat/research/gbml/postprocessed_metadata.proto";
import "snapchat/research/gbml/subgraph_sampling_strategy.proto";
/*
TODO: document all protos with comments.
*/
// Top-level task config: specifies task-related configuration that guides the
// behavior of the pipeline components according to the needs of the ML task.
// Values are usually provided as a YAML file whose keys are the lowerCamelCase
// form of the field names below (e.g. task_metadata -> taskMetadata).
message GbmlConfig {
  // Indicates the training task specification and metadata for the config.
  // Exactly one task type can be set (oneof task_metadata).
  message TaskMetadata {
    oneof task_metadata {
      // Node-level task supervised on NodeBasedTaskMetadata.supervision_node_types.
      NodeBasedTaskMetadata node_based_task_metadata = 1;
      // Node anchor based link prediction (NABLP) task supervised on the
      // edge types in NodeAnchorBasedLinkPredictionTaskMetadata.
      NodeAnchorBasedLinkPredictionTaskMetadata node_anchor_based_link_prediction_task_metadata = 2;
      // Link-level task supervised on LinkBasedTaskMetadata.supervision_edge_types.
      LinkBasedTaskMetadata link_based_task_metadata = 3;
    }
    message NodeBasedTaskMetadata {
      // Node types the node-based task is supervised on.
      repeated string supervision_node_types = 1;
    }
    message NodeAnchorBasedLinkPredictionTaskMetadata {
      // Edge types the link prediction task is supervised on.
      repeated EdgeType supervision_edge_types = 1;
    }
    message LinkBasedTaskMetadata {
      // Edge types the link-based task is supervised on.
      repeated EdgeType supervision_edge_types = 1;
    }
  }
  // Configuration shared across components: where each component writes its
  // output metadata, plus pipeline-wide flags.
  message SharedConfig {
    // Uri where DataPreprocessor generates the PreprocessedMetadata proto.
    string preprocessed_metadata_uri = 1;
    // FlattenedGraphMetadata message, which designates locations of GraphFlat outputs.
    FlattenedGraphMetadata flattened_graph_metadata = 2;
    // DatasetMetadata message, which designates location of SplitGenerator outputs.
    DatasetMetadata dataset_metadata = 3;
    // TrainedModelMetadata message, which designates location of Trainer outputs.
    TrainedModelMetadata trained_model_metadata = 4;
    // InferenceMetadata message, which designates location of Inferencer outputs.
    InferenceMetadata inference_metadata = 5;
    // PostProcessedMetadata message, which designates location of PostProcessor outputs.
    PostProcessedMetadata postprocessed_metadata = 12;
    // Free-form string key/value arguments; presumably shared across
    // components (TODO confirm which components read them).
    map<string, string> shared_args = 6;
    // Whether the graph is directed; if false, it is treated as undirected (bidirectional).
    bool is_graph_directed = 7;
    // Whether to skip training (inference-only runs).
    bool should_skip_training = 8;
    // If set to true, automatic clean up of temp assets is skipped.
    // Useful if you are running hyperparameter tuning jobs and don't want to
    // continuously re-run the whole pipeline.
    bool should_skip_automatic_temp_asset_cleanup = 9;
    // Whether to skip inference (training-only runs).
    bool should_skip_inference = 10;
    // If set, model metrics like MRR, etc. are not computed or exported.
    // Side effect: if should_skip_training is also set, training samples are
    // not generated and only the RNNs needed for inference are produced
    // (RNN presumably = rooted node neighborhood; TODO confirm).
    bool should_skip_model_evaluation = 11;
    // If set to true, isolated nodes are included in training data.
    // As isolated nodes do not have positive neighbors, a self loop is added:
    // SGS outputs training samples including isolated nodes, and the trainer
    // adds self loops in the training subgraphs.
    bool should_include_isolated_nodes_in_training = 13;
  }
  // Contains config related to generating training data for a GML task.
  message DatasetConfig {
    message DataPreprocessorConfig {
      // Uri pointing to user-written DataPreprocessorConfig class definition.
      string data_preprocessor_config_cls_path = 1;
      // Arguments to instantiate concrete DataPreprocessorConfig instance with.
      map<string, string> data_preprocessor_args = 2;
    }
    message SubgraphSamplerConfig {
      // Number of hops for the subgraph sampler to include.
      // Deprecated in favor of subgraph_sampling_strategy.
      uint32 num_hops = 1 [deprecated=true];
      // num_neighbors_to_sample indicates the max number of neighbors to sample for each hop.
      // num_neighbors_to_sample can be set to -1 to indicate no sampling (include all neighbors).
      // Deprecated in favor of subgraph_sampling_strategy.
      int32 num_neighbors_to_sample = 2 [deprecated=true];
      // num_hops and num_neighbors_to_sample are deprecated in favor of subgraph_sampling_strategy.
      // Used to specify how the graphs which are used for message passing are constructed.
      SubgraphSamplingStrategy subgraph_sampling_strategy = 10;
      // Number of positive samples (1 hop) used in NodeAnchorBasedLinkPredictionTask
      // as part of loss computation. It cannot be 0, and it is recommended to be larger
      // than 1: due to the split filtering logic in the split generator, this helps
      // guarantee that most samples have at least one positive so they are not
      // excluded from training.
      uint32 num_positive_samples = 3;
      // (deprecated / removed)
      // Number of hard negative samples (3,4 hops) used in NodeAnchorBasedLinkPredictionTask,
      // also used in loss computation. Random negatives will always be used even when there
      // are no hard negatives.
      // uint32 num_hard_negative_samples = 4;
      // NOTE(review): field number 4 and the name num_hard_negative_samples are no
      // longer declared but are not reserved; consider
      // `reserved 4; reserved "num_hard_negative_samples";` so they cannot be reused.
      // Arguments for experimental_flags, can be permutation_strategy: 'deterministic' or 'non-deterministic'.
      map<string, string> experimental_flags = 5;
      // Max number of training samples (i.e. nodes to store as main samples for training).
      // If this is not provided or is set to 0, all nodes will be included for training.
      uint32 num_max_training_samples_to_output = 6;
      // Number of user defined positive samples. Used in NodeAnchorBasedLinkPredictionTask
      // as part of loss computation.
      // If `num_user_defined_positive_samples` is specified, `num_positive_samples` will be
      // ignored, as positive samples will only be drawn from user defined positive samples.
      uint32 num_user_defined_positive_samples = 7 [deprecated=true];
      // Number of user defined negative samples.
      // Treated as hard negative samples. Used in NodeAnchorBasedLinkPredictionTask,
      // also used in loss computation. Random negatives will always be used even when there
      // are no user defined hard negatives.
      uint32 num_user_defined_negative_samples = 8 [deprecated=true];
      // If specified, the intention is to run ingestion into a graph DB for the subgraph sampler.
      GraphDBConfig graph_db_config = 9;
    }
    message SplitGeneratorConfig {
      // Module path to the concrete SplitStrategy class.
      string split_strategy_cls_path = 1;
      // Arguments to instantiate concrete SplitStrategy instance with.
      map<string, string> split_strategy_args = 2;
      // Module path to the concrete Assigner class.
      string assigner_cls_path = 3;
      // Arguments to instantiate concrete Assigner instance with.
      map<string, string> assigner_args = 4;
    }
    // Configuration for the DataPreprocessor component.
    DataPreprocessorConfig data_preprocessor_config = 1;
    // Configuration for the subgraph sampler component.
    SubgraphSamplerConfig subgraph_sampler_config = 2;
    // Configuration for the SplitGenerator component.
    SplitGeneratorConfig split_generator_config = 3;
  }
  // Generic Configuration for a GraphDB connection.
  message GraphDBConfig {
    // Python class path pointing to a user-written
    // `BaseIngestion` class definition, e.g. `my.team.graph_db.BaseIngestionImpl`.
    // This class is currently, as an implementation detail, used for ingestion only.
    // We document this *purely* for information purposes and may change the implementation at any time.
    string graph_db_ingestion_cls_path = 1;
    // Arguments to instantiate concrete BaseIngestion instance with.
    map<string, string> graph_db_ingestion_args = 2;
    // General arguments required for graphDB (graph space, port, etc.)
    // These are passed to both the Python and Scala implementations.
    map<string, string> graph_db_args = 3;
    // If provided, then an implementation of a `DBClient[DBResult]` Scala class
    // for a GraphDB.
    // Intended to be used to inject specific implementations at runtime.
    // The object constructed from this is currently, as an implementation detail, used for sampling only.
    // We document this *purely* for information purposes and may change the implementation at any time.
    GraphDBServiceConfig graph_db_sampler_config = 4;
    // Scala-specific configuration.
    message GraphDBServiceConfig {
      // Scala absolute class path pointing to an implementation of `DBClient[DBResult]`,
      // e.g. `my.team.graph_db.DBClient`.
      string graph_db_client_class_path = 1;
    }
  }
  // Configuration for the Trainer component.
  message TrainerConfig {
    // (deprecated)
    // Uri pointing to user-written BaseTrainer class definition. Used for the subgraph-sampling-based training process.
    // NOTE(review): documented as deprecated but not marked [deprecated=true].
    string trainer_cls_path = 1;
    // Arguments to parameterize training process with.
    map<string, string> trainer_args = 2;
    // Specifies how to execute training; at most one of cls_path / command is set.
    oneof executable {
      // Path pointing to trainer class definition.
      string cls_path = 100;
      // Command to use for launching trainer job.
      string command = 101;
    }
    // Whether to log to tensorboard or not (defaults to false).
    bool should_log_to_tensorboard = 12;
  }
  // Configuration for the Inferencer component.
  message InferencerConfig {
    // Arguments to parameterize the inference process with.
    map<string, string> inferencer_args = 1;
    // (deprecated)
    // Path to modeling task spec class path to construct model for inference. Used for the subgraph-sampling-based inference process.
    // NOTE(review): documented as deprecated but not marked [deprecated=true].
    string inferencer_cls_path = 2;
    // Specifies how to execute inference; at most one of cls_path / command is set.
    oneof executable {
      // Path pointing to inferencer class definition.
      string cls_path = 100;
      // Command to use for launching inference job.
      string command = 101;
    }
    // Optional. If set, inference samples are batched to this size before the inference call is made.
    // Defaults to the setting in python/gigl/src/inference/gnn_inferencer.py
    uint32 inference_batch_size = 5;
  }
  // Configuration for the PostProcessor component.
  message PostProcessorConfig {
    // Arguments for the post processor (string key/value pairs).
    map<string, string> post_processor_args = 1;
    // Class path of the post processor implementation.
    string post_processor_cls_path = 2;
  }
  // Configuration for metrics computation.
  message MetricsConfig {
    // Class path of the metrics implementation.
    string metrics_cls_path = 1;
    // Arguments for the metrics class (string key/value pairs).
    map<string, string> metrics_args = 2;
  }
  // Configuration for profiling.
  message ProfilerConfig {
    // Whether profiling is enabled.
    bool should_enable_profiler = 1;
    // Directory where profiler logs are written (presumably a URI; TODO confirm).
    string profiler_log_dir = 2;
    // Extra profiler arguments (string key/value pairs).
    map<string, string> profiler_args = 3;
  }
  // The task to perform on the graph.
  TaskMetadata task_metadata = 1;
  // Node and edge types present in the graph.
  GraphMetadata graph_metadata = 2;
  SharedConfig shared_config = 3;
  DatasetConfig dataset_config = 4;
  TrainerConfig trainer_config = 5;
  InferencerConfig inferencer_config = 6;
  // Note: field numbers are not in declaration order (this one is 9). Numbers,
  // not declaration order, define the wire format and must never change.
  PostProcessorConfig post_processor_config = 9;
  MetricsConfig metrics_config = 7;
  ProfilerConfig profiler_config = 8;
  // Free-form string feature flags.
  map<string, string> feature_flags = 10;
}
Just like the resource config, the values used to instantiate this proto class are usually provided
as a `.yaml` file. Most components accept the task config through the `--task_config_uri` argument,
i.e. a `gigl.common.Uri` pointing to a `task_config.yaml` file.
Example
We will use the MAG240M task config to walk you through what a config may look like.
Full task config for reference:
# ========
# TaskMetadata:
# Specifies the task we are going to perform on the graph.
taskMetadata:
  nodeAnchorBasedLinkPredictionTaskMetadata:
    # Specifying that we will perform node anchor based link prediction on edges of type: paper_or_author -> references -> paper_or_author
    supervisionEdgeTypes:
      - srcNodeType: paper_or_author
        relation: references
        dstNodeType: paper_or_author
# ========
# GraphMetadata:
# Specifies what edge and node types are present in the graph.
# Note: all the edge / node types here should be referenced in the preprocessor_config.
graphMetadata:
  edgeTypes:
    - dstNodeType: paper_or_author
      relation: references
      srcNodeType: paper_or_author
  nodeTypes:
    - paper_or_author
# ========
# SharedConfig:
# Specifies some extra metadata about the graph structure and the management of orchestration.
sharedConfig:
  isGraphDirected: true
  shouldSkipAutomaticTempAssetCleanup: true # Should we skip cleaning up the temporary assets after the run is complete?
# ========
# DatasetConfig:
# Specifies information about the dataset: how to access it, how to process it, and how to sample subgraphs from it.
datasetConfig:
  dataPreprocessorConfig:
    dataPreprocessorConfigClsPath: examples.MAG240M.preprocessor_config.Mag240DataPreprocessorConfig
    # Our implementation takes no runtime arguments; if provided, these are passed to the constructor of dataPreprocessorConfigClsPath.
    # dataPreprocessorArgs:
  subgraphSamplerConfig:
    # NOTE: numHops, numNeighborsToSample and numUserDefinedPositiveSamples are marked deprecated
    # in SubgraphSamplerConfig (gbml_config.proto) in favor of subgraphSamplingStrategy.
    numHops: 2 # Each subgraph that is computed will span 2 hops
    numNeighborsToSample: 15 # And we will sample (up to) 15 neighbors at each hop for each node
    numUserDefinedPositiveSamples: 1 # We will sample 1 positive sample per anchor node
  splitGeneratorConfig:
    assignerArgs:
      seed: '42'
      test_split: '0.2'
      train_split: '0.7'
      val_split: '0.1'
    # Since the positive labels are user defined, we use the following setup.
    # More assigner and split strategies can be found in splitgenerator.lib.assigners and
    # splitgenerator.lib.split_strategies respectively.
    assignerClsPath: splitgenerator.lib.assigners.UserDefinedLabelsEdgeToLinkSplitHashingAssigner
    splitStrategyClsPath: splitgenerator.lib.split_strategies.UserDefinedLabelsNodeAnchorBasedLinkPredictionSplitStrategy
# ========
# TrainerConfig:
# Specifies the training configuration: the trainer class and the arguments to pass to it.
# The trainer class is responsible for training the model, and the arguments are passed to its constructor.
trainerConfig:
  # GiGL provides a basic implementation of a NABLP trainer; customers are encouraged to extend this class to suit their needs.
  trainerClsPath: gigl.src.common.modeling_task_specs.NodeAnchorBasedLinkPredictionModelingTaskSpec
  trainerArgs: # The following arguments are passed to trainerClsPath's constructor. See class implementation for more details.
    early_stop_patience: '5'
    early_stop_criterion: 'loss'
    main_sample_batch_size: '512' # Reduce batch size on CUDA OOM. Note that the train/validation/test loss definition is associated with this batch size.
    num_test_batches: '400' # Increase this number to get a more stable test loss
    num_val_batches: '192'
    random_negative_sample_batch_size: '512'
    random_negative_sample_batch_size_for_evaluation: '1000' # The validation/test MRR and hit rate definitions are associated with this batch size.
    val_every_num_batches: '100' # Every 100 training batches, evaluate the model on the validation set (used for checkpointing / early stopping; see class implementation).
    # More data loaders prefetch more data into memory, which significantly saves data read and preprocess time.
    # However, it also significantly increases CPU memory consumption and could lead to CPU memory OOM.
    # The CPU memory consumption depends on both the number of data loaders and the batch size.
    train_main_num_workers: '10'
    train_random_negative_num_workers: '10'
    val_main_num_workers: '4'
    val_random_negative_num_workers: '4'
    test_main_num_workers: '8'
    test_random_negative_num_workers: '8'
# ========
# InferencerConfig:
# Specifies the inference configuration: the inferencer class and the arguments to pass to it.
# The inferencer class is responsible for running inference on the model, and the arguments are passed to its constructor.
inferencerConfig:
  # inferencerArgs: We don't need to pass any special arguments for the inferencer.
  # Note: The inferencerClsPath is the same as the trainerClsPath.
  # This is because NodeAnchorBasedLinkPredictionModelingTaskSpec implements both BaseTrainer (the interface a class must implement for training)
  # and BaseInferencer (the interface a class must implement for inference). See their respective definitions for more information:
  # - gigl.src.training.v1.lib.base_trainer.BaseTrainer
  # - gigl.src.inference.v1.lib.base_inferencer.BaseInferencer
  inferencerClsPath: gigl.src.common.modeling_task_specs.NodeAnchorBasedLinkPredictionModelingTaskSpec
# ========
GraphMetadata
Here we specify all the node and edge types in the graph. In this case we have one node type, `paper_or_author`,
and one edge type, `(paper_or_author, references, paper_or_author)`.
Note: In this example we have converted the heterogeneous MAG240M dataset to a homogeneous one with just one edge type and one node type, on which we will be doing self-supervised learning.
# Specifies what edge and node types are present in the graph.
# Note: all the edge / node types here should be referenced in the preprocessor_config.
graphMetadata:
  edgeTypes:
    - dstNodeType: paper_or_author
      relation: references
      srcNodeType: paper_or_author
  nodeTypes:
    - paper_or_author
TaskMetadata
Now we specify what type of learning task we want to do. In this case we want to leverage Node Anchor Based Link
Prediction to do self-supervised learning on the edge type `(paper_or_author, references, paper_or_author)`. Thus, we are
using the `NodeAnchorBasedLinkPredictionTaskMetadata` task.
# Specifies the task we are going to perform on the graph.
taskMetadata:
  nodeAnchorBasedLinkPredictionTaskMetadata:
    # Specifying that we will perform node anchor based link prediction on edges of type: paper_or_author -> references -> paper_or_author
    supervisionEdgeTypes:
      - srcNodeType: paper_or_author
        relation: references
        dstNodeType: paper_or_author
Note: An example of `NodeBasedTaskMetadata` can be found in
`python/gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml`.
DatasetConfig
Next, we create the dataset that we will be using. In this example we use the class specified by `dataPreprocessorConfigClsPath`
to read and preprocess the data. See the Preprocessor Guide.
Once the data is preprocessed, we tabularize it using the Subgraph Sampler. Specifically, for each node we sample its
`numHops`-hop neighborhood, where each hop samples up to `numNeighborsToSample` neighbors. We also sample
`numUserDefinedPositiveSamples` positive samples, along with their respective neighborhoods, using the same `numHops`
and `numNeighborsToSample` settings.
Subsequently, we create train/val/test splits based on the split ratios specified in `assignerArgs`, using the Split Generator.
# Specifies information about the dataset: how to access it, how to process it, and how to sample subgraphs from it.
datasetConfig:
  dataPreprocessorConfig:
    dataPreprocessorConfigClsPath: examples.MAG240M.preprocessor_config.Mag240DataPreprocessorConfig
    # Our implementation takes no runtime arguments; if provided, these are passed to the constructor of dataPreprocessorConfigClsPath.
    # dataPreprocessorArgs:
  subgraphSamplerConfig:
    # NOTE: numHops, numNeighborsToSample and numUserDefinedPositiveSamples are marked deprecated
    # in SubgraphSamplerConfig (gbml_config.proto) in favor of subgraphSamplingStrategy.
    numHops: 2 # Each subgraph that is computed will span 2 hops
    numNeighborsToSample: 15 # And we will sample (up to) 15 neighbors at each hop for each node
    numUserDefinedPositiveSamples: 1 # We will sample 1 positive sample per anchor node
  splitGeneratorConfig:
    assignerArgs:
      seed: '42'
      test_split: '0.2'
      train_split: '0.7'
      val_split: '0.1'
    # Since the positive labels are user defined, we use the following setup.
    # More assigner and split strategies can be found in splitgenerator.lib.assigners and
    # splitgenerator.lib.split_strategies respectively.
    assignerClsPath: splitgenerator.lib.assigners.UserDefinedLabelsEdgeToLinkSplitHashingAssigner
    splitStrategyClsPath: splitgenerator.lib.split_strategies.UserDefinedLabelsNodeAnchorBasedLinkPredictionSplitStrategy
TrainerConfig
The class specified by `trainerClsPath` will be initialized, and all the arguments specified in `trainerArgs` will be
passed directly as `**kwargs` to your trainer class. The only requirement is that the trainer class implements the protocol
defined in `gigl.src.training.v1.lib.base_trainer.BaseTrainer`.
Some sensible pre-configured trainer implementations can be found in `gigl.src.common.modeling_task_specs`,
although you are encouraged to implement your own.
# Specifies the training configuration: the trainer class and the arguments to pass to it.
# The trainer class is responsible for training the model, and the arguments are passed to its constructor.
trainerConfig:
  # GiGL provides a basic implementation of a NABLP trainer; customers are encouraged to extend this class to suit their needs.
  trainerClsPath: gigl.src.common.modeling_task_specs.NodeAnchorBasedLinkPredictionModelingTaskSpec
  trainerArgs: # The following arguments are passed to trainerClsPath's constructor. See class implementation for more details.
    early_stop_patience: '5'
    early_stop_criterion: 'loss'
    main_sample_batch_size: '512' # Reduce batch size on CUDA OOM. Note that the train/validation/test loss definition is associated with this batch size.
    num_test_batches: '400' # Increase this number to get a more stable test loss
    num_val_batches: '192'
    random_negative_sample_batch_size: '512'
    random_negative_sample_batch_size_for_evaluation: '1000' # The validation/test MRR and hit rate definitions are associated with this batch size.
    val_every_num_batches: '100' # Every 100 training batches, evaluate the model on the validation set (used for checkpointing / early stopping; see class implementation).
    # More data loaders prefetch more data into memory, which significantly saves data read and preprocess time.
    # However, it also significantly increases CPU memory consumption and could lead to CPU memory OOM.
    # The CPU memory consumption depends on both the number of data loaders and the batch size.
    train_main_num_workers: '10'
    train_random_negative_num_workers: '10'
    val_main_num_workers: '4'
    val_random_negative_num_workers: '4'
    test_main_num_workers: '8'
    test_random_negative_num_workers: '8'
InferencerConfig
Similar to the trainer, the class specified by `inferencerClsPath` will be initialized, and all arguments specified in
`inferencerArgs` will be passed directly as `**kwargs` to your inferencer class. The only requirement is that the inferencer
class implements the protocol defined in `gigl.src.inference.v1.lib.base_inferencer.BaseInferencer`.
# Specifies the inference configuration: the inferencer class and the arguments to pass to it.
# The inferencer class is responsible for running inference on the model, and the arguments are passed to its constructor.
inferencerConfig:
  # inferencerArgs: We don't need to pass any special arguments for the inferencer.
  # Note: The inferencerClsPath is the same as the trainerClsPath.
  # This is because NodeAnchorBasedLinkPredictionModelingTaskSpec implements both BaseTrainer (the interface a class must implement for training)
  # and BaseInferencer (the interface a class must implement for inference). See their respective definitions for more information:
  # - gigl.src.training.v1.lib.base_trainer.BaseTrainer
  # - gigl.src.inference.v1.lib.base_inferencer.BaseInferencer
  inferencerClsPath: gigl.src.common.modeling_task_specs.NodeAnchorBasedLinkPredictionModelingTaskSpec