Trainer#
The Trainer component is the entry point for model training in GiGL. It supports both the legacy tabularized pipeline and the newer in-memory subgraph sampling path.
Input#
job_name (AppliedTaskIdentifier): uniquely identifies an end-to-end task.
task_config_uri (Uri): Path which points to a “frozen” GbmlConfig proto yaml file. It can be created manually, or (the recommended approach) generated from a template config by the config_populator component.
resource_config_uri (Uri): Path which points to a GiGLResourceConfig yaml.
What does it do?#
The Trainer undertakes the following actions:
Reads the frozen GbmlConfig and resource config.
Cleans existing trainer output paths so retries do not mix old and new assets.
Chooses the training backend:
Legacy tabularized path when featureFlags.should_run_glt_backend is not enabled.
In-memory subgraph sampling path when featureFlags.should_run_glt_backend is True.
Launches the selected training runtime and persists output metadata such as model parameters and offline metrics.
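For reference, the backend flag lives in the frozen config. A hypothetical fragment follows; the field names are those used in this doc, and the command value is a placeholder:

```yaml
# Hypothetical frozen GbmlConfig fragment; only the fields discussed above
featureFlags:
  should_run_glt_backend: 'True'  # selects the in-memory subgraph sampling path
trainerConfig:
  command: python -m my_project.train  # user training entry point (placeholder)
```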
In-Memory Subgraph Sampling Path#
In the in-memory path, the Trainer launches the distributed runtime used for live neighborhood sampling. At a high level, that runtime:
launches the user-provided training command from trainerConfig.command,
uses Data Preprocessor outputs to build a DistDataset or RemoteDistDataset,
samples neighborhoods online during training instead of consuming precomputed sampled subgraphs.
For link prediction, the reference training loops under examples/link_prediction use:
DistABLPLoader for anchor-based link prediction batches,
DistNeighborLoader for random negative batches.
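To make the online-vs-precomputed contrast concrete, here is a toy, pure-Python sketch of online neighborhood sampling. This is not the GiGL loader API; the function name and graph data are made up:

```python
import random

# Toy graph: node id -> list of neighbor ids (made-up data)
ADJ = {0: [1, 2, 3], 1: [0], 2: [0, 3], 3: [0, 2]}

def sample_neighborhood(adj, seed, fanout, rng):
    """Draw up to `fanout` neighbors of `seed` on the fly, as an online
    sampler would, instead of reading a precomputed sampled subgraph."""
    neighbors = adj.get(seed, [])
    return rng.sample(neighbors, min(fanout, len(neighbors)))

rng = random.Random(42)
# Each training step can draw a fresh sample for the same seed node:
print(sample_neighborhood(ADJ, seed=0, fanout=2, rng=rng))
```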
Legacy path#
In the legacy path, the Trainer consumes the outputs of Split Generator and delegates to the v1 trainer stack.
How do I run it?#
Import GiGL
from gigl.src.training.trainer import Trainer
from gigl.common import UriFactory
from gigl.src.common.types import AppliedTaskIdentifier

trainer = Trainer()
trainer.run(
    applied_task_identifier=AppliedTaskIdentifier("sample_job_name"),
    task_config_uri=UriFactory.create_uri("gs://MY_TEMP_ASSETS_BUCKET/frozen_task_config.yaml"),
    resource_config_uri=UriFactory.create_uri("gs://MY_TEMP_ASSETS_BUCKET/resource_config.yaml"),
)
Note: If you are training on Vertex AI and using a custom class, you will have to provide a docker image (either cuda_docker_uri for GPU training or cpu_docker_uri for CPU training).
For in-memory subgraph sampling training, the component currently supports Vertex AI execution. The example training
scripts under examples/link_prediction can still be run directly for local experimentation with an already frozen task
config.
Command Line
python -m \
gigl.src.training.trainer \
--job_name="sample_job_name" \
--task_config_uri="gs://MY_TEMP_ASSETS_BUCKET/frozen_task_config.yaml" \
--resource_config_uri="gs://MY_TEMP_ASSETS_BUCKET/resource_config.yaml"
Output#
After the training process finishes:
The Trainer saves the trained model’s state_dict at the specified location (the trainedModelUri field of sharedConfig.trainedModelMetadata).
The Trainer logs training metrics to the trainingLogsUri field of sharedConfig.trainedModelMetadata. To view the metrics locally, you can run: tensorboard --logdir gs://tensorboard_logs_uri_here
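The output locations come from the frozen config’s sharedConfig.trainedModelMetadata. A hypothetical fragment follows; the URIs are placeholders:

```yaml
sharedConfig:
  trainedModelMetadata:
    trainedModelUri: gs://my_bucket/trained_models/model.pt  # state_dict destination
    trainingLogsUri: gs://my_bucket/training_logs/           # tensorboard logs
```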
Examples#
Reference in-memory training implementations:
Custom Usage#
The customization point depends on the backend:
Legacy path: training logic is provided through a BaseTrainer implementation.
In-memory path: training logic is provided by the user command referenced in trainerConfig.command, such as the example scripts under examples/link_prediction.
Other#
Torch Profiler#
You can profile trainer performance metrics, such as GPU/CPU utilization, by adding the below to your task_config.yaml. The profiler_args entries mirror PyTorch profiler options (e.g. wait is a torch.profiler.schedule parameter; with_stack records source information):
profilerConfig:
  should_enable_profiler: true
  profiler_log_dir: gs://path_to_my_bucket (or a local dir)
  profiler_args:
    wait: '0'
    with_stack: 'True'
Monitoring and logging#
Once the trainer component starts, the training process can be monitored via the gcloud console under Vertex AI Custom
Jobs (https://console.cloud.google.com/vertex-ai/training/custom-jobs?project=<project_name_here>). You can also view
the job name, status, jobspec, and more using gcloud ai custom-jobs list --project <project_name_here>
On the Vertex AI UI, you can see information such as machine/accelerator configuration, CPU utilization, GPU utilization, network data, etc. Here you will also find the “View logs” tab, which opens the Stackdriver logs for your job; these log everything from your modeling task spec in real time as training progresses.
If you would like to view the logs locally, you can also use:
gcloud ai custom-jobs stream-logs <custom job ID> --project=<project_name_here> --region=<region here>.
Parameters#
We provide some base class implementations for training. See:
gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py
gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py
gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py
Note: many training/model parameters depend on using the right model / training setup, i.e. specific configurations may not be supported; see the individual implementations to understand how each parameter is used. Training specs are fully customizable; these are only examples.
The v1 modeling-task-spec implementations provide runtime arguments similar to below. We present examples of the args
for node_anchor_based_link_prediction_modeling_task_spec.py here. These are most relevant to the legacy path;
in-memory training scripts typically define their own runtime arguments in trainerArgs.
Training environment parameters (number of workers for different dataloaders)
train_main_num_workers
train_random_negative_num_workers
val_main_num_workers
val_random_negative_num_workers
test_main_num_workers
test_random_negative_num_workers
Note that training involves multiple dataloaders running simultaneously. Take care to specify these parameters in a way which avoids overburdening your machine. It is recommended to keep
(train_main_num_workers + train_random_negative_num_workers + val_main_num_workers + val_random_negative_num_workers) < num_cpus, and
(test_main_num_workers + test_random_negative_num_workers) < num_cpus
to avoid training stalling due to contention.
Modifying the GNN model:
Specified by the arg gnn_model_class_path. Some sample GNN models are defined here and initialized in the init_model function in the ModelingTaskSpec. When trying different GNN models, it is recommended to also include the new GNN architectures under the same file and declare them as is currently done. This cannot currently be done from the default GbmlConfig yaml.
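The CPU-budget rule of thumb above can be sanity-checked with a small helper. This is a toy sketch, not a GiGL API; the function name and worker counts are made up:

```python
import os

def fits_cpu_budget(worker_counts, num_cpus):
    """True when the summed dataloader worker counts leave at least one
    CPU free, per the recommendation above."""
    return sum(worker_counts) < num_cpus

num_cpus = os.cpu_count() or 1
# train/val loaders: main + random-negative workers for each split (example values)
train_val_workers = [2, 2, 1, 1]
print(fits_cpu_budget(train_val_workers, num_cpus))
```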
Non Exhaustive list of Model parameters:
hidden_dim: dimension of the hidden layers
num_layers: number of layers in the GNN
out_channels: dimension of the output embeddings
should_l2_normalize_embedding_layer_output: whether to apply L2 normalization on the output embeddings
Non Exhaustive list of Training parameters:
num_heads
val_every_num_batches: validation frequency per training batches
num_val_batches: number of validation batches
num_test_batches: number of testing batches
optim_class_path: defaults to “torch.optim.Adam”
optim_lr: learning rate of the optimizer
optim_weight_decay: weight decay of the optimizer
clip_grad_norm
lr_scheduler_name: defaults to “torch.optim.lr_scheduler.ConstantLR”
factor: param for lr scheduler
total_iters: param for lr scheduler
main_sample_batch_size: training batch size
random_negative_sample_batch_size: random negative sample batch size for training
random_negative_sample_batch_size_for_evaluation: random negative sample batch size for evaluation
train_main_num_workers
val_main_num_workers
test_main_num_workers
train_random_negative_num_workers
val_random_negative_num_workers
test_random_negative_num_workers
early_stop_criterion: defaults to “loss”
early_stop_patience: patience for early stopping
task_path: python class path to supported training tasks, i.e. Retrieval: gigl.src.common.models.layers.task.Retrieval; see gigl.src.common.models.layers.task for more info
softmax_temp: temperature parameter in the softmax loss
should_remove_accidental_hits
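These parameters are typically passed as strings via trainerArgs in the task config. A hypothetical fragment follows; the values are examples only:

```yaml
trainerArgs:
  hidden_dim: '128'
  num_layers: '2'
  out_channels: '64'
  main_sample_batch_size: '256'
  optim_lr: '0.001'
  early_stop_criterion: 'loss'
  early_stop_patience: '3'
```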
Background for distributed training#
Trainer currently uses PyTorch distributed training abstractions to enable multi-node and multi-GPU training. Some useful terminology, and links to learn about these abstractions, are below.
WORLD: Group of processes/workers that are used for distributed training.
WORLD_SIZE: The number of processes/workers in the distributed training WORLD.
RANK: The unique id (usually index) of the process/worker in the distributed training WORLD.
Data loader worker: A worker used specifically for loading data; if the dataloader worker is utilizing the same thread/process as a worker in distributed training WORLD, then we may incur blocking execution of training, resulting in slowdowns.
Distributed Data Parallel: PyTorch’s version of data parallelism across different processes (which could even be processes on different machines), to speed up training on large datasets.
TORCH.DISTRIBUTED package: A torch package containing tools for distributed communication and training.
Defines backends for distributed communication like gloo and nccl. As an ML practitioner you should not worry about how these work, but it is important to know which devices and collective functions they support.
Contains “collective functions” like torch.distributed.broadcast, torch.distributed.all_gather, et al., which allow communication of tensors across the WORLD.
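The WORLD/RANK terminology can be illustrated with a stdlib-only sketch. The assumption here is that a launcher (such as torchrun) sets RANK and WORLD_SIZE in the environment; the sharding helper is a toy illustration, not a GiGL or torch API:

```python
import os

def shard_for(rank, world_size, data):
    """Each rank takes a disjoint, interleaved shard of the data: the core
    idea behind Distributed Data Parallel training."""
    return data[rank::world_size]

# Launchers like torchrun export these for every worker in the WORLD;
# fall back to a single-process WORLD when they are unset.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
print(rank, world_size, shard_for(rank, world_size, list(range(8))))
```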