Source code for xmodaler.engine.rl_beam_trainer

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo
@contact: jianjieluo.sysu@gmail.com
"""
import time
import torch
import numpy as np
from .defaults import DefaultTrainer
from xmodaler.scorer import build_scorer
from xmodaler.config import kfg
import xmodaler.utils.comm as comm
from .build import ENGINE_REGISTRY

__all__ = ['RLBeamTrainer']

@ENGINE_REGISTRY.register()
class RLBeamTrainer(DefaultTrainer):
    def __init__(self, cfg):
        super(RLBeamTrainer, self).__init__(cfg)
        self.scorer = self.build_scorer(cfg)

    @classmethod
    def build_scorer(cls, cfg):
        return build_scorer(cfg)

    def run_step(self):
        start = time.perf_counter()
        try:
            data = next(self._train_data_loader_iter)
        except StopIteration:
            # data loader exhausted: restart the iterator for the next pass
            self._train_data_loader_iter = iter(self.train_data_loader)
            data = next(self._train_data_loader_iter)
        data_time = time.perf_counter() - start

        data = comm.unwrap_model(self.model).preprocess_batch(data)
        beam_size = comm.unwrap_model(self.model).beam_searcher.beam_size
        batch_size = data['ATT_FEATS'].shape[0]
        device = data['ATT_FEATS'].device

        self.model.train()
        # set out_size in beam search decode
        data['OUT_SIZE'] = beam_size
        outputs_dict = self.model(data, use_beam_search=True, output_sents=False)

        # repeat IDS to compute cider rewards
        outputs_dict[kfg.IDS] = np.repeat(
            np.expand_dims(np.array(outputs_dict[kfg.IDS]).flatten(), axis=1),
            beam_size, axis=1
        ).flatten().tolist()
        outputs_dict[kfg.G_SENTS_IDS] = outputs_dict[kfg.G_SENTS_IDS].view(batch_size * beam_size, -1)

        rewards = self.scorer(outputs_dict)
        reward = rewards[kfg.REWARDS].astype(np.float32)
        reward = torch.from_numpy(reward).to(device).view(batch_size, beam_size)
        # baseline is the mean reward over the beams of the same sample
        reward_baseline = torch.mean(reward, -1, keepdim=True)

        # policy-gradient loss: weight each beam's log-prob by its advantage
        loss = -torch.mean(outputs_dict[kfg.G_LOGP], -1) * (reward - reward_baseline)
        loss = loss.mean()
        losses_dict = { 'BeamRewardCriterion': loss }

        losses = sum(losses_dict.values())

        self.optimizer.zero_grad()
        losses.backward()

        # log the remaining scorer outputs alongside the loss
        rewards.pop(kfg.REWARDS)
        losses_dict.update(rewards)
        self._write_metrics(losses_dict, data_time)
        self.optimizer.step()

        if self.ema is not None:
            self.ema.update(self.model)
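

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal,
# self-contained view of the beam-level baseline used in run_step above.
# All shapes and values are made up for demonstration (the (batch, beam,
# seq_len) layout is what run_step's use of kfg.G_LOGP suggests); only the
# torch import at the top of this file is required.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, beam_size, seq_len = 2, 3, 5

    # fake per-token log-probabilities of the beam-searched sequences
    g_logp = -torch.rand(batch_size, beam_size, seq_len)

    # fake sentence-level rewards (e.g. CIDEr), one scalar per (sample, beam)
    reward = torch.rand(batch_size, beam_size)

    # baseline = mean reward across the beams of the same sample, so beams
    # scoring above their siblings receive a positive advantage
    reward_baseline = torch.mean(reward, -1, keepdim=True)

    # weight each beam's mean log-prob by its advantage, then average
    loss = -torch.mean(g_logp, -1) * (reward - reward_baseline)
    loss = loss.mean()
    print('sketch loss:', loss.item())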