gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec#
Classes#
- GraphSageTemplateTrainerSpec – Template Simple Training Spec that uses GraphSAGE for Node Anchor Based Link Prediction with DDP support.
Module Contents#
- class gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec.GraphSageTemplateTrainerSpec(**kwargs)[source]#
Bases: gigl.src.training.v1.lib.base_trainer.BaseTrainer, gigl.src.inference.v1.lib.base_inferencer.NodeAnchorBasedLinkPredictionBaseInferencer
Template Simple Training Spec that uses GraphSAGE for Node Anchor Based Link Prediction with DDP support. Arguments are passed in via trainerArgs in the GBML config (see the usage sketch after the parameter list).
- Parameters:
hidden_dim (int) – Hidden dimension to use for the model (default: 64)
num_layers (int) – Number of layers to use for the model (default: 2)
out_channels (int) – Output channels to use for the model (default: 64)
validate_every_n_batches (int) – Number of training batches between validation runs (default: 20)
num_val_batches (int) – Number of batches to validate on (default: 10)
num_test_batches (int) – Number of batches to test on (default: 100)
early_stop_patience (int) – Number of consecutive checks without improvement to trigger early stopping (default: 3)
num_epochs (int) – Number of epochs to train the model for (default: 5)
optim_lr (float) – Learning rate to use for the optimizer (default: 0.001)
main_sample_batch_size (int) – Batch size to use for the main samples (default: 256)
random_negative_batch_size (int) – Batch size to use for the random negative samples (default: 64)
train_main_num_workers (int) – Number of workers to use for the main-sample training dataloader (default: 2)
val_main_num_workers (int) – Number of workers to use for the main-sample validation dataloader (default: 1)
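The following is a minimal sketch (not the canonical pipeline entry point) of how the arguments above map onto the spec: it constructs GraphSageTemplateTrainerSpec directly with keyword arguments that mirror the trainerArgs entries of a GBML config. In a real run these values are read from the config proto, and any string-to-number parsing is handled by the spec itself.

```python
from gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec import (
    GraphSageTemplateTrainerSpec,
)

# Keyword arguments mirror the trainerArgs entries of the GBML config;
# anything omitted falls back to the defaults listed above.
spec = GraphSageTemplateTrainerSpec(
    hidden_dim=128,   # default: 64
    num_layers=3,     # default: 2
    out_channels=64,  # default: 64
    num_epochs=10,    # default: 5
    optim_lr=1e-3,    # default: 0.001
)
```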
- eval(gbml_config_pb_wrapper, device)[source]#
Evaluate the model using the test data loaders (see the usage sketch below).
- Parameters:
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper) – GbmlConfigPbWrapper for gbmlConfig proto
device (torch.device) – torch.device to run the evaluation on
- Return type:
gigl.src.common.types.model_eval_metrics.EvalMetricsCollection
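A hedged usage sketch for eval, assuming spec is a GraphSageTemplateTrainerSpec instance and gbml_config_pb_wrapper is a GbmlConfigPbWrapper that has already been constructed elsewhere in the pipeline:

```python
import torch

# `spec` and `gbml_config_pb_wrapper` are assumed to exist already (see above).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Returns an EvalMetricsCollection computed over the test data loaders.
metrics = spec.eval(gbml_config_pb_wrapper=gbml_config_pb_wrapper, device=device)
```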
- infer_batch(batch, device=torch.device('cpu'))[source]#
- Parameters:
batch – Input batch to run inference on
device (torch.device) – Device to run inference on; defaults to torch.device('cpu')
- Return type:
- init_model(gbml_config_pb_wrapper, state_dict=None, device=torch.device('cuda'))[source]#
- Parameters:
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
state_dict (Optional[dict])
device (torch.device)
- Return type:
torch.nn.Module
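A sketch of the expected call pattern for init_model, under the same assumptions as above (spec and gbml_config_pb_wrapper already exist); state_dict is optional and can be used to restore previously saved weights:

```python
import torch

# Build the torch.nn.Module described by the GBML config. Passing a previously
# saved state dict (instead of None) restores trained weights.
model = spec.init_model(
    gbml_config_pb_wrapper=gbml_config_pb_wrapper,
    state_dict=None,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
```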
- train(gbml_config_pb_wrapper, device, profiler=None)[source]#
Main training loop for the GraphSAGE model (see the call sketch below).
- Parameters:
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper) – GbmlConfigPbWrapper for gbmlConfig proto
device (torch.device) – torch.device to run the training on
num_epochs – Number of epochs to train the model for (taken from the spec's num_epochs setting rather than passed to train directly)
profiler – Optional profiler object used to profile the training run (default: None)
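A hedged sketch of kicking off training, again assuming spec and gbml_config_pb_wrapper already exist; passing a profiler is optional, and the exact profiler type expected here is an assumption:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Run the main training loop over the configured number of epochs.
# `profiler` defaults to None; a torch.profiler.profile instance is one
# plausible choice, but that is an assumption rather than a documented contract.
spec.train(
    gbml_config_pb_wrapper=gbml_config_pb_wrapper,
    device=device,
    profiler=None,
)
```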
- validate(main_data_loader, random_negative_data_loader, device)[source]#
Compute the validation loss for the model from the similarity scores of the positive and negative samples (see the conceptual sketch below).
- Parameters:
main_data_loader (torch.utils.data.dataloader._BaseDataLoaderIter) – DataLoader iterator over the positive (main) samples
random_negative_data_loader (torch.utils.data.dataloader._BaseDataLoaderIter) – DataLoader iterator over the random negative samples
device (torch.device) – torch.device to run the validation on
- Returns:
Average validation loss
- Return type:
float
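For intuition only, here is a self-contained sketch of one common way to turn similarity scores for positive and random-negative pairs into a validation loss (binary cross-entropy over dot-product logits). It is an illustrative formulation, not necessarily the loss this spec implements:

```python
import torch
import torch.nn.functional as F

def similarity_validation_loss(
    anchor_emb: torch.Tensor,  # [N, D] anchor node embeddings
    pos_emb: torch.Tensor,     # [N, D] embeddings of positive neighbors
    neg_emb: torch.Tensor,     # [M, D] embeddings of random negative nodes
) -> torch.Tensor:
    # One dot-product score per (anchor, positive) pair.
    pos_scores = (anchor_emb * pos_emb).sum(dim=-1)  # [N]
    # Score every anchor against every random negative.
    neg_scores = anchor_emb @ neg_emb.t()            # [N, M]

    logits = torch.cat([pos_scores, neg_scores.flatten()])
    labels = torch.cat(
        [torch.ones_like(pos_scores), torch.zeros_like(neg_scores.flatten())]
    )
    # Binary cross-entropy treats positives as label 1, negatives as label 0.
    return F.binary_cross_entropy_with_logits(logits, labels)
```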