gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec
Attributes
Classes
| GraphSageTemplateTrainerSpec | Template simple training spec that uses GraphSAGE for node-anchor-based link prediction, with DDP support. |
Module Contents
- class gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec.GraphSageTemplateTrainerSpec(**kwargs)
- Bases: gigl.src.training.v1.lib.base_trainer.BaseTrainer, gigl.src.inference.v1.lib.base_inferencer.NodeAnchorBasedLinkPredictionBaseInferencer
- Template simple training spec that uses GraphSAGE for node-anchor-based link prediction, with DDP (DistributedDataParallel) support. Arguments are passed in via trainerArgs in the GBML config; any argument that is omitted takes the default shown.
- Parameters:
- hidden_dim (int) – Hidden dimension to use for the model (default: 64) 
- num_layers (int) – Number of layers to use for the model (default: 2) 
- out_channels (int) – Output channels to use for the model (default: 64) 
- validate_every_n_batches (int) – Run validation once every this many training batches (default: 20) 
- num_val_batches (int) – Number of batches to validate on (default: 10) 
- num_test_batches (int) – Number of batches to test on (default: 100) 
- early_stop_patience (int) – Number of consecutive checks without improvement to trigger early stopping (default: 3) 
- num_epochs (int) – Number of epochs to train the model for (default: 5) 
- optim_lr (float) – Learning rate to use for the optimizer (default: 0.001) 
- main_sample_batch_size (int) – Batch size to use for the main samples (default: 256) 
- random_negative_batch_size (int) – Batch size to use for the random negative samples (default: 64) 
- train_main_num_workers (int) – Number of workers to use for the train main dataloader (default: 2) 
- val_main_num_workers (int) – Number of workers to use for the val main dataloader (default: 1) 
 
 - eval(gbml_config_pb_wrapper, device)
- Evaluate the model using the test data loaders.
- Parameters:
- gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper) – GbmlConfigPbWrapper for gbmlConfig proto 
- device (torch.device) – torch.device to run the evaluation on 
 
- Return type:
- gigl.src.common.types.model_eval_metrics.EvalMetricsCollection 
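This page does not say which metrics `eval` stores in the returned `EvalMetricsCollection`. Link-prediction evaluations commonly report ranking metrics such as mean reciprocal rank (MRR) and hits@k. The following is a minimal pure-Python sketch of those two metrics, assuming each query node has one positive score and a list of negative scores. The function name and input layout are hypothetical, and this is not necessarily what `eval` computes.

```python
from typing import List, Sequence, Tuple


def mrr_and_hits(
    pos_scores: Sequence[float],
    neg_scores: Sequence[Sequence[float]],
    k: int = 10,
) -> Tuple[float, float]:
    """Compute MRR and hits@k with one positive candidate per query.

    pos_scores[i] is the score of query i's positive candidate and
    neg_scores[i] holds the scores of its negative candidates. Ties are
    counted pessimistically: a negative with an equal score ranks above
    the positive.
    """
    if len(pos_scores) != len(neg_scores):
        raise ValueError("pos_scores and neg_scores must have the same length")
    if not pos_scores:
        raise ValueError("at least one query is required")
    reciprocal_ranks: List[float] = []
    hits = 0
    for pos, negs in zip(pos_scores, neg_scores):
        # Rank = 1 + number of negatives scoring at least as high as the positive.
        rank = 1 + sum(1 for n in negs if n >= pos)
        reciprocal_ranks.append(1.0 / rank)
        hits += int(rank <= k)
    n = len(pos_scores)
    return sum(reciprocal_ranks) / n, hits / n


# Query 0 ranks its positive first, query 1 ranks it third:
# MRR = (1 + 1/3) / 2 = 2/3, hits@1 = 1/2.
mrr, hits_at_1 = mrr_and_hits([0.9, 0.2], [[0.1, 0.5], [0.3, 0.8]], k=1)
```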
 
 - infer_batch(batch, device=torch.device('cpu'))
- Parameters:
- device (torch.device) 
 
- Return type:
 
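The entry above documents neither the type of `batch` nor the return type. The following is a minimal sketch of the standard PyTorch inference pattern that such a method typically follows, assuming `batch` is a node-feature tensor and the model returns one embedding per row. `infer_batch_sketch` and the toy `nn.Linear` model are hypothetical stand-ins, not GiGL code.

```python
import torch
from torch import nn


def infer_batch_sketch(
    model: nn.Module,
    batch: torch.Tensor,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Run a forward pass without tracking gradients and return CPU embeddings."""
    model = model.to(device)
    model.eval()  # disables dropout and uses running batch-norm statistics
    with torch.no_grad():
        embeddings = model(batch.to(device))
    return embeddings.cpu()


toy_model = nn.Linear(16, 64)  # hypothetical stand-in for the GraphSAGE encoder
out = infer_batch_sketch(toy_model, torch.randn(8, 16))
print(out.shape)  # torch.Size([8, 64])
```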
 - init_model(gbml_config_pb_wrapper, state_dict=None, device=torch.device('cuda'))
- Parameters:
- gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper) 
- state_dict (Optional[dict]) 
- device (torch.device) 
 
- Return type:
- torch.nn.Module 
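One plausible reading of the signature is that `init_model` builds the network from the `hidden_dim`, `num_layers` and `out_channels` settings, restores `state_dict` when one is given, and moves the result to `device`. The sketch below follows that reading. It uses a simplified GraphSAGE encoder with mean neighbor aggregation over a dense adjacency matrix. `MeanSAGELayer`, `SimpleGraphSage`, `init_model_sketch` and the `in_dim` argument are illustrative names, not GiGL's model. The sketch defaults to CPU so it runs without a GPU, whereas the documented signature defaults to `cuda`.

```python
from typing import Optional

import torch
from torch import nn


class MeanSAGELayer(nn.Module):
    """Simplified GraphSAGE layer: W_self * h_v + W_neigh * (mean of neighbor h_u)."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.lin_self = nn.Linear(in_dim, out_dim)
        self.lin_neigh = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # adj is a dense (N, N) 0/1 matrix; clamp avoids division by zero for isolated nodes.
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        neigh_mean = (adj @ x) / deg
        return self.lin_self(x) + self.lin_neigh(neigh_mean)


class SimpleGraphSage(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, num_layers: int, out_channels: int) -> None:
        super().__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be at least 1")
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_channels]
        self.layers = nn.ModuleList(
            [MeanSAGELayer(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x, adj)
            if i < len(self.layers) - 1:
                x = torch.relu(x)  # no activation after the final layer
        return x


def init_model_sketch(
    in_dim: int,
    hidden_dim: int = 64,
    num_layers: int = 2,
    out_channels: int = 64,
    state_dict: Optional[dict] = None,
    device: torch.device = torch.device("cpu"),
) -> nn.Module:
    model = SimpleGraphSage(in_dim, hidden_dim, num_layers, out_channels)
    if state_dict is not None:
        model.load_state_dict(state_dict)  # restore previously trained weights
    return model.to(device)


x = torch.randn(4, 16)
adj = torch.tensor(
    [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=torch.float
)
model = init_model_sketch(in_dim=16)
restored = init_model_sketch(in_dim=16, state_dict=model.state_dict())
print(model(x, adj).shape)  # torch.Size([4, 64])
```

Production GraphSAGE implementations, such as PyTorch Geometric's `SAGEConv`, use sparse message passing over an edge index rather than a dense adjacency matrix. The dense form is used here only to keep the sketch short.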
 
 - train(gbml_config_pb_wrapper, device, profiler=None)
- Main training loop for the GraphSAGE model.
- Parameters:
- gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper) – GbmlConfigPbWrapper for gbmlConfig proto 
- device (torch.device) – torch.device to run the training on 
- profiler – Optional profiler object used to profile the training (default: None) 
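The constructor arguments `validate_every_n_batches`, `early_stop_patience` and `num_epochs` point to a loop that validates periodically and stops after `early_stop_patience` consecutive validations without improvement. The following is a minimal pure-Python sketch of that control flow, assuming that structure. The training step and validation are replaced by hypothetical callables, and `train_with_early_stopping` is not the method's actual code.

```python
from typing import Callable, Iterable, List


def train_with_early_stopping(
    epochs: Iterable[Iterable[object]],
    train_step: Callable[[object], None],
    validate: Callable[[], float],
    validate_every_n_batches: int = 20,
    early_stop_patience: int = 3,
) -> List[float]:
    """Call train_step on every batch of every epoch and call validate every
    validate_every_n_batches batches. Stop once the validation loss has not
    improved for early_stop_patience consecutive checks. Returns the list of
    validation losses observed."""
    best_loss = float("inf")
    checks_without_improvement = 0
    history: List[float] = []
    batch_count = 0
    for epoch_batches in epochs:
        for batch in epoch_batches:
            train_step(batch)
            batch_count += 1
            if batch_count % validate_every_n_batches != 0:
                continue
            val_loss = validate()
            history.append(val_loss)
            if val_loss < best_loss:
                best_loss = val_loss
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1
                if checks_without_improvement >= early_stop_patience:
                    return history  # early stop
    return history


losses = iter([0.9, 0.7, 0.8, 0.75, 0.72, 0.6])
seen: List[object] = []
history = train_with_early_stopping(
    epochs=[range(10), range(10)],  # two epochs of ten dummy batches
    train_step=seen.append,
    validate=lambda: next(losses),
    validate_every_n_batches=2,
    early_stop_patience=3,
)
print(history, len(seen))  # [0.9, 0.7, 0.8, 0.75, 0.72] 10
```

In the usage example, the three checks after the best loss of 0.7 fail to improve on it, so training stops after ten batches even though twenty were available.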
 
 
 - validate(main_data_loader, random_negative_data_loader, device)
- Compute the validation loss from the similarity scores of the positive and negative samples.
- Parameters:
- main_data_loader (torch.utils.data.dataloader._BaseDataLoaderIter) – Iterator over the DataLoader for the positive (main) samples 
- random_negative_data_loader (torch.utils.data.dataloader._BaseDataLoaderIter) – Iterator over the DataLoader for the random negative samples 
- device (torch.device) – torch.device to run the validation on 
 
- Returns:
- Average validation loss 
- Return type:
- float 
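`validate` reduces positive and negative similarity scores to an average loss, but this page does not name the loss function. The following is a minimal sketch that assumes a binary cross-entropy objective on raw scores, with positives labeled 1 and random negatives labeled 0, averaged over at most `num_val_batches` batches. The loss choice, `bce_with_logits` and `average_validation_loss` are assumptions. Plain lists of scores stand in for the two data-loader iterators and the model's similarity computation.

```python
import math
from typing import Iterable, Sequence, Tuple


def bce_with_logits(score: float, label: float) -> float:
    """Numerically stable binary cross-entropy on a raw (pre-sigmoid) score."""
    # max(s, 0) - s*y + log(1 + exp(-|s|)) == -[y*log(sigmoid(s)) + (1-y)*log(1-sigmoid(s))]
    return max(score, 0.0) - score * label + math.log1p(math.exp(-abs(score)))


def average_validation_loss(
    batches: Iterable[Tuple[Sequence[float], Sequence[float]]],
    num_val_batches: int = 10,
) -> float:
    """Average per-score BCE over up to num_val_batches (pos_scores, neg_scores) pairs."""
    total, count = 0.0, 0
    for i, (pos_scores, neg_scores) in enumerate(batches):
        if i >= num_val_batches:
            break
        total += sum(bce_with_logits(s, 1.0) for s in pos_scores)
        total += sum(bce_with_logits(s, 0.0) for s in neg_scores)
        count += len(pos_scores) + len(neg_scores)
    if count == 0:
        raise ValueError("no validation scores were provided")
    return total / count


# Pair each positive-score batch with its random-negative batch, as the two
# loaders would be consumed in step.
pos_batches = [[2.0, 1.5], [0.5]]
neg_batches = [[-1.0, 0.3], [-2.0]]
loss = average_validation_loss(zip(pos_batches, neg_batches), num_val_batches=10)
```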
 
 
