Source code for xmodaler.lr_scheduler.warmup_lr

import math
from bisect import bisect_right
import warnings
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from xmodaler.config import configurable
from .build import LR_SCHEDULER_REGISTRY

@LR_SCHEDULER_REGISTRY.register()
class WarmupConstant(LambdaLR):
    """
    Linear warmup and then constant.
    Linearly increases the learning rate schedule from 0 to 1 over `warmup_steps` training steps.
    Keeps the learning rate schedule equal to 1. after `warmup_steps`.
    """
    @configurable
    def __init__(
        self,
        *,
        optimizer,
        warmup_steps,
        last_epoch=-1
    ):
        self.warmup_steps = warmup_steps
        super(WarmupConstant, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    @classmethod
    def from_config(cls, cfg, optimizer, data_size):
        return {
            "optimizer": optimizer,
            "warmup_steps": cfg.LR_SCHEDULER.WARMUP * data_size,
            "last_epoch": -1
        }

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        return 1.
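
# Hedged usage sketch (not part of the original module): shows the learning rates
# produced by WarmupConstant on a throwaway SGD optimizer. Constructing the
# scheduler with explicit keyword arguments (bypassing from_config/cfg) is an
# assumption about how @configurable behaves when no cfg is passed; the helper
# name below is illustrative only.
def _demo_warmup_constant():
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.1)
    sched = WarmupConstant(optimizer=opt, warmup_steps=4)
    lrs = []
    for _ in range(8):
        opt.step()
        sched.step()
        lrs.append(opt.param_groups[0]["lr"])
    # Ramps 0.025 -> 0.05 -> 0.075 -> 0.1, then stays at the base lr (0.1).
    return lrs
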
@LR_SCHEDULER_REGISTRY.register()
class WarmupLinear(LambdaLR):
    """
    Linear warmup and then linear decay.
    Linearly increases the learning rate from 0 to 1 over `warmup_steps` training steps.
    Linearly decreases the learning rate from 1. to 0. over the remaining
    `t_total - warmup_steps` steps, clamped from below at `min_lr`.
    """
    @configurable
    def __init__(
        self,
        *,
        optimizer,
        min_lr,
        warmup_steps,
        t_total,
        last_epoch=-1
    ):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.min_lr = min_lr
        super(WarmupLinear, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    @classmethod
    def from_config(cls, cfg, optimizer, data_size):
        return {
            "optimizer": optimizer,
            "min_lr": cfg.LR_SCHEDULER.MIN_LR,
            "warmup_steps": cfg.LR_SCHEDULER.WARMUP * data_size,
            "t_total": cfg.SOLVER.EPOCH * data_size,  # total iterations
            "last_epoch": -1
        }

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        return max(self.min_lr, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
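
# Hedged usage sketch (not part of the original module): samples the WarmupLinear
# multiplier via lr_lambda so the ramp, the linear decay, and the `min_lr` floor
# are visible. The direct keyword construction is an assumption about
# @configurable; in training the scheduler would normally be built from a cfg.
def _demo_warmup_linear():
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=1.0)
    sched = WarmupLinear(optimizer=opt, min_lr=0.2, warmup_steps=2, t_total=10)
    # Multiplier: step / 2 during warmup, then (10 - step) / 8, never below 0.2.
    return [sched.lr_lambda(step) for step in range(12)]
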
@LR_SCHEDULER_REGISTRY.register()
class WarmupCosine(LambdaLR):
    """
    Linear warmup and then cosine decay.
    Linearly increases the learning rate from 0 to 1 over `warmup_steps` training steps.
    Decreases the learning rate from 1. to 0. (clamped from below at `min_lr`) over the
    remaining `t_total - warmup_steps` steps following a cosine curve.
    If `cycles` differs from the default of 0.5, the learning rate follows `cycles`
    full cosine periods after warmup.
    """
    @configurable
    def __init__(
        self,
        *,
        optimizer,
        min_lr,
        warmup_steps,
        t_total,
        cycles=.5,
        last_epoch=-1
    ):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        self.min_lr = min_lr
        super(WarmupCosine, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    @classmethod
    def from_config(cls, cfg, optimizer, data_size):
        return {
            "optimizer": optimizer,
            "min_lr": cfg.LR_SCHEDULER.MIN_LR,
            "warmup_steps": cfg.LR_SCHEDULER.WARMUP * data_size,
            "t_total": cfg.SOLVER.EPOCH * data_size,  # total iterations
            "cycles": .5,
            "last_epoch": -1
        }

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        return max(self.min_lr, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
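
# Hedged usage sketch (not part of the original module): samples the WarmupCosine
# multiplier directly via lr_lambda. With the default cycles=0.5 the multiplier
# traces half a cosine period after warmup, i.e. a smooth decay from 1. toward
# `min_lr`. Direct keyword construction is an assumption about @configurable.
def _demo_warmup_cosine():
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=1.0)
    sched = WarmupCosine(optimizer=opt, min_lr=0.0, warmup_steps=2, t_total=10)
    # e.g. step=2 -> 1.0, step=6 -> 0.5 (midpoint of the half period), step=10 -> 0.0
    return [round(sched.lr_lambda(step), 3) for step in range(11)]
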
@LR_SCHEDULER_REGISTRY.register()
class WarmupCosineWithHardRestarts(LambdaLR):
    """
    Linear warmup and then cosine cycles with hard restarts.
    Linearly increases the learning rate from 0 to 1 over `warmup_steps` training steps.
    After warmup, the learning rate follows `cycles` (default=1.) cosine decays from
    1. to 0., each one restarting at 1. (hard restart).
    """
    @configurable
    def __init__(
        self,
        *,
        optimizer,
        warmup_steps,
        t_total,
        cycles=1.,
        last_epoch=-1
    ):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineWithHardRestarts, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    @classmethod
    def from_config(cls, cfg, optimizer, data_size):
        return {
            "optimizer": optimizer,
            "warmup_steps": cfg.LR_SCHEDULER.WARMUP * data_size,
            "t_total": cfg.SOLVER.EPOCH * data_size,  # total iterations
            "cycles": 1.,
            "last_epoch": -1
        }

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))
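
# Hedged usage sketch (not part of the original module): with cycles=2 the
# post-warmup multiplier completes two cosine decays, jumping back to 1. at the
# halfway point (the "hard restart"). Direct keyword construction is an
# assumption about @configurable; the values below are illustrative only.
def _demo_warmup_cosine_hard_restarts():
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=1.0)
    sched = WarmupCosineWithHardRestarts(optimizer=opt, warmup_steps=2, t_total=10, cycles=2.)
    # e.g. step=2 -> 1.0, step=5 -> ~0.15, step=6 -> 1.0 again (restart), step=10 -> 0.0
    return [round(sched.lr_lambda(step), 3) for step in range(11)]
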
@LR_SCHEDULER_REGISTRY.register()
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Multi-step decay: multiplies the learning rate by `gamma` at each milestone,
    with a constant or linear warmup over the first `warmup_iters` iterations.
    """
    @configurable
    def __init__(
        self,
        *,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    @classmethod
    def from_config(cls, cfg, optimizer, data_size):
        steps = [step * data_size for step in cfg.LR_SCHEDULER.STEPS]
        return {
            "optimizer": optimizer,
            "milestones": steps,
            "gamma": cfg.LR_SCHEDULER.GAMMA,
            "warmup_factor": cfg.LR_SCHEDULER.WARMUP_FACTOR,
            "warmup_iters": cfg.LR_SCHEDULER.WARMUP * data_size,
            "warmup_method": cfg.LR_SCHEDULER.WARMUP_METHOD,
            "last_epoch": -1
        }

    def get_lr(self):
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        # step decay: gamma ** (number of milestones already passed)
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]
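
# Hedged usage sketch (not part of the original module): combines a short linear
# warmup with gamma decay at the given milestones. The explicit keyword call and
# the tiny milestone values are illustrative assumptions (about @configurable and
# about sensible settings), not values taken from any particular config.
def _demo_warmup_multistep():
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=1.0)
    sched = WarmupMultiStepLR(
        optimizer=opt,
        milestones=[6, 9],
        gamma=0.1,
        warmup_factor=0.5,
        warmup_iters=3,
        warmup_method="linear",
    )
    lrs = []
    for _ in range(10):
        opt.step()
        sched.step()
        lrs.append(opt.param_groups[0]["lr"])
    # Warmup blends 0.5 -> 1.0 over the first 3 iterations, then the lr drops
    # by a factor of 10 after iterations 6 and 9.
    return lrs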