# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import time
import copy
import torch
from .defaults import DefaultTrainer
from xmodaler.scorer import build_scorer
from xmodaler.config import kfg
from xmodaler.losses import build_rl_losses
import xmodaler.utils.comm as comm
from .build import ENGINE_REGISTRY
__all__ = ['RLTrainer']
@ENGINE_REGISTRY.register()
class RLTrainer(DefaultTrainer):
    """Trainer whose step optimizes reward-based (RL) losses.

    Each step decodes every batch twice: once in eval mode without gradients
    to obtain baseline rewards, and once in train mode with
    ``kfg.DECODE_BY_SAMPLE`` enabled. The difference between the two scorer
    rewards is stored under ``kfg.REWARDS`` and handed to the losses built by
    ``build_rl_losses``.
    """

    def __init__(self, cfg):
        """Set up the base trainer, then build the scorer and the RL losses.

        Args:
            cfg: configuration object forwarded to ``DefaultTrainer``,
                ``build_scorer`` and ``build_rl_losses``.
        """
        super().__init__(cfg)
        self.scorer = self.build_scorer(cfg)
        self.losses = build_rl_losses(cfg)

    @classmethod
    def build_scorer(cls, cfg):
        """Return the scorer created by ``xmodaler.scorer.build_scorer(cfg)``.

        Kept as a classmethod so subclasses can override scorer construction.
        """
        return build_scorer(cfg)

    def run_step(self):
        """Run one training iteration: fetch, decode twice, score, update."""
        start = time.perf_counter()
        try:
            data = next(self._train_data_loader_iter)
        except StopIteration:
            # The loader is exhausted: start a new pass over the data. In a
            # multi-process run the sampler is told the new epoch index first.
            if comm.get_world_size() > 1:
                self.train_data_loader.sampler.set_epoch(self.iter//self.iters_per_epoch)
            self._train_data_loader_iter = iter(self.train_data_loader)
            data = next(self._train_data_loader_iter)
        data_time = time.perf_counter() - start

        data = comm.unwrap_model(self.model).preprocess_batch(data)

        # Baseline pass: eval mode, no gradients, sample flag not set.
        self.model.eval()
        with torch.no_grad():
            # Shallow copy, so top-level keys added or replaced on the copy
            # do not show up in `data`, which is reused for the second pass.
            bs_data = copy.copy(data)
            bs_outputs_dict = self.model(bs_data, use_beam_search=False, output_sents=False)
        bs_rewards = self.scorer(bs_outputs_dict)

        # Sampled pass: train mode with gradients enabled.
        self.model.train()
        data[kfg.DECODE_BY_SAMPLE] = True
        outputs_dict = self.model(data, use_beam_search=False, output_sents=False)
        rewards = self.scorer(outputs_dict)

        # Reward relative to the baseline pass; the scorer values are passed
        # through torch.from_numpy, so they are expected to be numpy arrays.
        # NOTE(review): `.cuda()` targets the current CUDA device, so this
        # step cannot run on a CPU-only setup -- confirm that is intended.
        rewards = torch.from_numpy(rewards[kfg.REWARDS] - bs_rewards[kfg.REWARDS]).float().cuda()
        outputs_dict.update({kfg.REWARDS: rewards})

        losses_dict = {}
        for loss in self.losses:
            losses_dict.update(loss(outputs_dict))

        # Entries whose key contains 'acc' are logged but not optimized.
        losses = sum(value for key, value in losses_dict.items() if 'acc' not in key)

        self.optimizer.zero_grad()
        losses.backward()

        # Log whatever else the baseline scorer returned alongside the losses;
        # the reward entry itself is dropped before merging.
        bs_rewards.pop(kfg.REWARDS)
        losses_dict.update(bs_rewards)
        self._write_metrics(losses_dict, data_time)

        self.optimizer.step()
        if self.ema is not None:
            self.ema.update(self.model)