gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec#

Attributes#

logger

Classes#

GraphSageTemplateTrainerSpec

Simple template training spec that uses GraphSAGE for node anchor based link prediction, with DDP support.

Module Contents#

class gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec.GraphSageTemplateTrainerSpec(**kwargs)[source]#

Bases: gigl.src.training.v1.lib.base_trainer.BaseTrainer, gigl.src.inference.v1.lib.base_inferencer.NodeAnchorBasedLinkPredictionBaseInferencer

Simple template training spec that uses GraphSAGE for node anchor based link prediction, with DDP support. Arguments are passed in via trainerArgs in the GBML config.

Parameters:
  • hidden_dim (int) – Hidden dimension to use for the model (default: 64)

  • num_layers (int) – Number of layers to use for the model (default: 2)

  • out_channels (int) – Output channels to use for the model (default: 64)

  • validate_every_n_batches (int) – Run validation every this many training batches (default: 20)

  • num_val_batches (int) – Number of batches to validate on (default: 10)

  • num_test_batches (int) – Number of batches to test on (default: 100)

  • early_stop_patience (int) – Number of consecutive checks without improvement to trigger early stopping (default: 3)

  • num_epochs (int) – Number of epochs to train the model for (default: 5)

  • optim_lr (float) – Learning rate to use for the optimizer (default: 0.001)

  • main_sample_batch_size (int) – Batch size to use for the main samples (default: 256)

  • random_negative_batch_size (int) – Batch size to use for the random negative samples (default: 64)

  • train_main_num_workers (int) – Number of workers to use for the train main dataloader (default: 2)

  • val_main_num_workers (int) – Number of workers to use for the val main dataloader (default: 1)
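
All of these arguments arrive through trainerArgs in the GBML config, where values are typically strings. As a rough illustration only — the real parsing happens inside the spec's `__init__`, and `resolve_trainer_args` below is a hypothetical helper, not part of GiGL — the override-with-defaults behavior can be sketched as:

```python
# Documented defaults for GraphSageTemplateTrainerSpec, restated here
# purely for illustration.
DEFAULTS = {
    "hidden_dim": 64,
    "num_layers": 2,
    "out_channels": 64,
    "validate_every_n_batches": 20,
    "num_val_batches": 10,
    "num_test_batches": 100,
    "early_stop_patience": 3,
    "num_epochs": 5,
    "optim_lr": 1e-3,
    "main_sample_batch_size": 256,
    "random_negative_batch_size": 64,
    "train_main_num_workers": 2,
    "val_main_num_workers": 1,
}

def resolve_trainer_args(**kwargs):
    """Merge user-supplied overrides into the defaults, casting each value
    to the default's type (config values often arrive as strings).
    Hypothetical sketch; the real spec does this internally."""
    resolved = dict(DEFAULTS)
    for key, value in kwargs.items():
        if key not in DEFAULTS:
            raise KeyError(f"Unknown trainer arg: {key}")
        resolved[key] = type(DEFAULTS[key])(value)
    return resolved
```

For example, passing `hidden_dim="128"` in trainerArgs would override the default of 64 while every unlisted argument keeps its documented default.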

eval(gbml_config_pb_wrapper, device)[source]#

Evaluate the model using the test data loaders.

Parameters:
  • gbml_config_pb_wrapper – GBML config wrapper used to build the test data loaders

  • device (torch.device) – torch.device to run the evaluation on

Return type:

gigl.src.common.types.model_eval_metrics.EvalMetricsCollection

infer_batch(batch, device=torch.device('cpu'))[source]#

Run inference on a single batch.

Parameters:
  • batch – Batch of data to run inference on

  • device (torch.device) – torch.device to run inference on (default: torch.device('cpu'))

Return type:

gigl.src.inference.v1.lib.base_inferencer.InferBatchResults

init_model(gbml_config_pb_wrapper, state_dict=None, device=torch.device('cuda'))[source]#

Initialize the model, optionally loading weights from a provided state dict.

Parameters:
  • gbml_config_pb_wrapper – GBML config wrapper used to configure the model

  • state_dict – Optional state dict of model weights to load (default: None)

  • device (torch.device) – torch.device to place the model on (default: torch.device('cuda'))

Return type:

torch.nn.Module
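
To make the interplay of hidden_dim, num_layers, and out_channels concrete: in a typical GraphSAGE stack, every layer but the last maps to hidden_dim, and the final layer maps to out_channels. The helper below is a hypothetical sketch of that wiring for illustration, not the code init_model actually runs:

```python
def graphsage_layer_dims(in_dim, hidden_dim=64, num_layers=2, out_channels=64):
    """Return (in_channels, out_channels) pairs for a GraphSAGE-style
    conv stack: all intermediate layers map to hidden_dim, the last
    layer maps to out_channels. Illustrative sketch only."""
    dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_channels]
    return list(zip(dims[:-1], dims[1:]))
```

With the documented defaults and 32-dimensional input features, this yields two layers: (32, 64) followed by (64, 64).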

setup_for_training()[source]#
train(gbml_config_pb_wrapper, device, profiler=None)[source]#

Main training loop for the GraphSAGE model.

Parameters:
  • gbml_config_pb_wrapper – GBML config wrapper used to build the training data loaders

  • device (torch.device) – torch.device to run training on

  • profiler – Optional profiler to instrument the training loop (default: None)

validate(main_data_loader, random_negative_data_loader, device)[source]#

Compute the validation loss for the model from the similarity scores of the positive and negative samples.

Parameters:
  • main_data_loader (torch.utils.data.dataloader._BaseDataLoaderIter) – DataLoader for the positive samples

  • random_negative_data_loader (torch.utils.data.dataloader._BaseDataLoaderIter) – DataLoader for the random negative samples

  • device (torch.device) – torch.device to run the validation on

Returns:

Average validation loss

Return type:

float
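
For intuition, one common way to turn positive and negative similarity scores into a single validation loss is binary cross-entropy, with positive pairs labeled 1 and random negatives labeled 0. The function below is an illustrative sketch of that formulation, not necessarily the loss this spec uses:

```python
import math

def link_pred_val_loss(pos_scores, neg_scores):
    """Mean binary cross-entropy over raw similarity scores: positives
    are pushed toward label 1, random negatives toward label 0.
    Illustrative sketch of one common link-prediction loss."""
    def bce(score, label):
        p = 1.0 / (1.0 + math.exp(-score))  # sigmoid of the similarity score
        return -(label * math.log(p) + (1 - label) * math.log(1 - p))
    losses = [bce(s, 1.0) for s in pos_scores] + [bce(s, 0.0) for s in neg_scores]
    return sum(losses) / len(losses)
```

Well-separated scores (high positives, low negatives) drive this loss toward zero, which is why it is a reasonable quantity to track with validate_every_n_batches and early_stop_patience.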

early_stop_patience[source]#
hidden_dim[source]#
main_sample_batch_size[source]#
property model: torch.nn.Module[source]#
Return type:

torch.nn.Module

num_epochs[source]#
num_layers[source]#
num_test_batches[source]#
num_val_batches[source]#
optim_lr[source]#
out_channels[source]#
random_negative_batch_size[source]#
property supports_distributed_training: bool[source]#
Return type:

bool

validate_every_n_batches[source]#
gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec.logger[source]#