# Source code for xmodaler.optim.sgd

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from xmodaler.config import configurable
from .build import SOLVER_REGISTRY

@SOLVER_REGISTRY.register()
class SGD(torch.optim.SGD):
    """Config-driven wrapper around :class:`torch.optim.SGD`.

    The class is registered in ``SOLVER_REGISTRY`` (presumably so the solver
    builder can look it up by name -- confirm against ``.build``). Its
    hyper-parameters can be read from ``cfg.SOLVER`` via :meth:`from_config`.
    """

    @configurable
    def __init__(
        self,
        *,
        params,
        lr=0.1,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
    ):
        """Create the optimizer; all arguments are keyword-only.

        NOTE(review): ``@configurable`` presumably also allows calling this
        with a config object, routing it through ``from_config`` -- confirm in
        ``xmodaler.config``.

        Args:
            params: parameters (or parameter-group dicts) to optimize, passed
                through unchanged to ``torch.optim.SGD``.
            lr: learning rate.
            momentum: momentum factor.
            dampening: dampening applied to momentum.
            weight_decay: weight decay (L2 penalty) coefficient.
            nesterov: whether to use Nesterov momentum.
        """
        # Forward by keyword so the mapping does not rely on the positional
        # parameter order of torch.optim.SGD.__init__.
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

    @classmethod
    def from_config(cls, cfg, params):
        """Translate ``cfg.SOLVER`` settings into ``__init__`` keyword arguments.

        Args:
            cfg: config object exposing ``SOLVER.BASE_LR``, ``SOLVER.MOMENTUM``,
                ``SOLVER.DAMPENING``, ``SOLVER.WEIGHT_DECAY`` and
                ``SOLVER.NESTEROV``.
            params: parameters (or parameter groups) to optimize.

        Returns:
            dict: keyword arguments for :meth:`__init__`.
        """
        return {
            "params": params,
            "lr": cfg.SOLVER.BASE_LR,
            "momentum": cfg.SOLVER.MOMENTUM,
            "dampening": cfg.SOLVER.DAMPENING,
            "weight_decay": cfg.SOLVER.WEIGHT_DECAY,
            "nesterov": cfg.SOLVER.NESTEROV,
        }