# Source code for gigl.src.common.utils.scheduler
import numpy as np
class CosineDecayScheduler:
    """A cosine decay scheduler that is used instead of torch.optim.lr_scheduler.CosineAnnealingLR
    since we need a scheduler specifically for EMA weight averaging (momentum).

    The value returned by :meth:`get` is piecewise:

    * ``step < warmup_steps``: linear ramp, ``max_val * step / warmup_steps``.
    * ``warmup_steps <= step <= total_steps``: half-cosine decay from ``max_val``
      (at ``warmup_steps``) down to 0 (at ``total_steps``).
    * ``step > total_steps``: raises ``ValueError``.

    This class is from https://github.com/nerdslab/bgrl/blob/dec99f8c605e3c4ae2ece57f3fa1d41f350d11a9/bgrl/scheduler.py

    Args:
        max_val: Peak value of the schedule, reached at the end of warmup.
        warmup_steps: Number of steps in the linear warmup phase.
        total_steps: Step at which the cosine decay reaches 0.
    """

    def __init__(self, max_val, warmup_steps, total_steps):
        # max_val must be stored: get() reads it in both the warmup and decay branches.
        self.max_val = max_val
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def get(self, step):
        """Return the scheduled value at ``step``.

        Args:
            step: Current step index.

        Returns:
            The warmup or cosine-decayed value for ``step`` (see class docstring).

        Raises:
            ValueError: If ``step`` is greater than ``total_steps``.

        Note:
            If ``warmup_steps == total_steps``, the decay window has zero length and
            ``get(total_steps)`` divides by zero.
        """
        if step < self.warmup_steps:
            return self.max_val * step / self.warmup_steps
        # Reaching here implies step >= warmup_steps, so only the upper bound needs checking.
        if step <= self.total_steps:
            decay_steps = self.total_steps - self.warmup_steps
            # Progress through the decay phase, mapped onto [0, pi] for a half-cosine.
            angle = (step - self.warmup_steps) * np.pi / decay_steps
            return self.max_val * (1 + np.cos(angle)) / 2
        raise ValueError(
            "Step ({}) > total number of steps ({}).".format(step, self.total_steps)
        )