Source code for gigl.src.common.utils.scheduler

import numpy as np


class CosineDecayScheduler:
    """Cosine decay scheduler with a linear warmup phase.

    Used instead of ``torch.optim.lr_scheduler.CosineAnnealingLR`` because we
    need a scheduler for EMA weight averaging (momentum) rather than one bound
    to an optimizer's learning rate.

    The scheduled value behaves in two phases:

    * It rises linearly from 0 to ``max_val`` over the first ``warmup_steps``
      steps.
    * It then follows a half cosine from ``max_val`` down to 0, reaching 0 at
      ``total_steps``.

    This class is from
    https://github.com/nerdslab/bgrl/blob/dec99f8c605e3c4ae2ece57f3fa1d41f350d11a9/bgrl/scheduler.py
    """

    def __init__(self, max_val: float, warmup_steps: int, total_steps: int):
        """Store the schedule parameters.

        Args:
            max_val: Peak value, reached at the end of the warmup phase.
            warmup_steps: Number of steps in the linear warmup phase.
            total_steps: Last valid step. ``get`` raises ``ValueError`` for
                any larger step.
        """
        self.max_val = max_val
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def get(self, step: int) -> float:
        """Return the scheduled value at ``step``.

        Args:
            step: Current step. Values below ``warmup_steps`` fall in the
                linear warmup phase. Values from ``warmup_steps`` up to
                ``total_steps`` fall in the cosine decay phase.

        Returns:
            ``max_val * step / warmup_steps`` during warmup. Otherwise
            ``max_val * (1 + cos(pi * progress)) / 2``, where ``progress`` goes
            from 0 at ``warmup_steps`` to 1 at ``total_steps``.

        Raises:
            ValueError: If ``step > total_steps``.
        """
        if step < self.warmup_steps:
            return self.max_val * step / self.warmup_steps
        if step > self.total_steps:
            raise ValueError(
                "Step ({}) > total number of steps ({}).".format(step, self.total_steps)
            )

        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps == 0:
            # The decay phase has zero length: warmup_steps == total_steps, so
            # step == total_steps here. Treat the schedule as fully decayed.
            # This matches get(total_steps) == 0 in every non-degenerate
            # configuration and avoids dividing by zero.
            angle = np.pi
        else:
            # Same expression order as before, so float results are unchanged.
            angle = (step - self.warmup_steps) * np.pi / decay_steps
        return self.max_val * (1 + np.cos(angle)) / 2