Source code for xmodaler.lr_scheduler.noam_lr

import torch
from xmodaler.config import configurable
from .build import LR_SCHEDULER_REGISTRY

@LR_SCHEDULER_REGISTRY.register()
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Noam schedule from "Attention Is All You Need": the learning rate
    increases linearly for `warmup` iterations, then decays proportionally
    to the inverse square root of the iteration number.
    """
    @configurable
    def __init__(
        self,
        *,
        optimizer,
        model_size,
        factor,
        warmup,
        last_epoch=-1,
    ):
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        # _LRScheduler.__init__ performs an initial step that calls get_lr(),
        # so the schedule parameters must be set before delegating to it.
        super(NoamLR, self).__init__(optimizer, last_epoch)
    @classmethod
    def from_config(cls, cfg, optimizer, data_size):
        return {
            "optimizer": optimizer,
            "model_size": cfg.LR_SCHEDULER.MODEL_SIZE,
            "factor": cfg.LR_SCHEDULER.FACTOR,
            "warmup": cfg.LR_SCHEDULER.WARMUP,  # iterations
            "last_epoch": -1,
        }
    def get_lr(self):
        # factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5),
        # where step is the number of scheduler.step() calls so far.
        # base_lrs is only used to emit one value per parameter group;
        # the optimizer's base learning rate does not enter the formula.
        return [
            self.factor
            * self.model_size ** (-0.5)
            * min(
                (self.last_epoch + 1) ** (-0.5),
                (self.last_epoch + 1) * self.warmup ** (-1.5),
            )
            for _ in self.base_lrs
        ]
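
Below is a minimal usage sketch, assuming xmodaler's @configurable decorator
passes explicit keyword arguments straight through to __init__ (as in the
detectron2 implementation it follows); the model, optimizer, and parameter
values are illustrative only, not taken from any xmodaler config.

import torch
from xmodaler.lr_scheduler.noam_lr import NoamLR

# Illustrative values: a 512-dim model with the 4000-iteration warmup
# used in the original Transformer paper.
model = torch.nn.Linear(512, 512)
# The optimizer's own lr is never read by get_lr(); the schedule is fully
# determined by factor, model_size, and warmup.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0)
scheduler = NoamLR(optimizer=optimizer, model_size=512, factor=1.0, warmup=4000)

for step in range(10):
    optimizer.step()
    scheduler.step()                # advance one iteration
    print(scheduler.get_last_lr())  # rises linearly during warmup, then decays as step**-0.5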