# Source code for xmodaler.engine.vcr_trainer

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import time
import tqdm
import copy
import numpy as np
import itertools
import torch
from .defaults import DefaultTrainer
from xmodaler.scorer import build_scorer
from xmodaler.config import kfg
from xmodaler.losses import build_rl_losses
import xmodaler.utils.comm as comm
from xmodaler.datasets import build_xmodaler_train_loader, build_xmodaler_valtest_loader, build_dataset_mapper
from .build import ENGINE_REGISTRY

__all__ = ['VCRTrainer']

@ENGINE_REGISTRY.register()
class VCRTrainer(DefaultTrainer):
    """Trainer for the VCR (Visual Commonsense Reasoning) task.

    VCR has two sub-tasks — question->answer (Q-A) and question+answer->rationale
    (QA-R) — so this trainer keeps *two* train/val data loaders (one per sub-task)
    and runs one optimization step on each of them per ``run_step`` call.
    """

    def __init__(self, cfg):
        """Delegate all setup (model, loaders, optimizer, losses) to DefaultTrainer."""
        super(VCRTrainer, self).__init__(cfg)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the two training loaders.

        Returns:
            list: ``[q2a_loader, qa2r_loader]`` — index 0 is the Q-A sub-task,
            index 1 is the QA-R sub-task. The stage string is parsed downstream
            by the dataset mapper (format ``"<split>;<subtask>"``).
        """
        q2a_dataset_mapper = build_dataset_mapper(cfg, name=cfg.DATASETS.TRAIN, stage="train;VCR_Q-A")
        q2a_data_loader = build_xmodaler_train_loader(cfg, dataset_mapper=q2a_dataset_mapper)

        qa2r_dataset_mapper = build_dataset_mapper(cfg, name=cfg.DATASETS.TRAIN, stage="train;VCR_QA-R")
        qa2r_data_loader = build_xmodaler_train_loader(cfg, dataset_mapper=qa2r_dataset_mapper)
        return [q2a_data_loader, qa2r_data_loader]

    @classmethod
    def build_val_loader(cls, cfg):
        """Build the two validation loaders (same pairing as the train loaders).

        NOTE(review): mirrors the original code, which reads ``cfg.DATASETS.TRAIN``
        (not ``cfg.DATASETS.VAL``) for the dataset name even at the "val" stage —
        the split is selected via the stage string instead.
        """
        q2a_dataset_mapper = build_dataset_mapper(cfg, name=cfg.DATASETS.TRAIN, stage="val;VCR_Q-A")
        q2a_data_loader = build_xmodaler_valtest_loader(cfg, dataset_mapper=q2a_dataset_mapper)

        qa2r_dataset_mapper = build_dataset_mapper(cfg, name=cfg.DATASETS.TRAIN, stage="val;VCR_QA-R")
        qa2r_data_loader = build_xmodaler_valtest_loader(cfg, dataset_mapper=qa2r_dataset_mapper)
        return [q2a_data_loader, qa2r_data_loader]

    @classmethod
    def test(cls, cfg, model, test_data_loader, evaluator, epoch):
        """Evaluate ``model`` on each sub-task loader and return the evaluator result.

        Args:
            cfg: config node; ``DATALOADER.SEQ_PER_SAMPLE`` is the number of
                candidate choices per question (logits are grouped by it).
            model: the network; temporarily switched to eval mode, restored to
                train mode before returning.
            test_data_loader (list): one loader per sub-task (Q-A, QA-R).
            evaluator: object with ``eval(results_list, epoch)``, or ``None``.
            epoch (int): current epoch, forwarded to the evaluator.

        Returns:
            The evaluator's result, or ``''`` when ``evaluator`` is ``None``.
        """
        model.eval()
        results_list = []
        with torch.no_grad():
            for loader_idx in range(len(test_data_loader)):
                results = []
                for data in tqdm.tqdm(test_data_loader[loader_idx]):
                    data = comm.unwrap_model(model).preprocess_batch(data)
                    res = model(data)

                    # Regroup the flat per-choice logits so each row holds the
                    # SEQ_PER_SAMPLE candidate scores of one question.
                    u_logits = res[kfg.U_LOGITS]
                    u_logits = u_logits.view(-1, cfg.DATALOADER.SEQ_PER_SAMPLE)
                    # All SEQ_PER_SAMPLE rows of a question share one id; keep one.
                    questions_ids = data[kfg.IDS].reshape((-1, cfg.DATALOADER.SEQ_PER_SAMPLE))[:, 0]

                    probs = torch.softmax(u_logits, dim=-1)
                    # torch.max over dim 1 returns (values, indices); [1] = argmax choice.
                    outputs = torch.max(probs, 1)[1].data.cpu().numpy()
                    targets = data[kfg.U_TARGET_IDS].view(-1).data.cpu().numpy()

                    # Renamed from `id` to avoid shadowing the builtin.
                    for question_id, output, target in zip(questions_ids, outputs, targets):
                        results.append({
                            cfg.INFERENCE.ID_KEY: int(question_id),
                            cfg.INFERENCE.VALUE: output,
                            kfg.U_TARGET_IDS: target
                        })
                results_list.append(results)

        if evaluator is not None:
            eval_res = evaluator.eval(results_list, epoch)
        else:
            eval_res = ''
        model.train()
        return eval_res

    def run_step(self):
        """Run one optimization step on each sub-task loader (Q-A then QA-R).

        Each loader gets its own full forward/backward/step cycle; exhausted
        iterators are restarted so training can run for arbitrarily many steps.
        """
        for loader_idx in range(len(self.train_data_loader)):
            assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
            start = time.perf_counter()
            try:
                data = next(self._train_data_loader_iter_list[loader_idx])
            except StopIteration:
                # Loader exhausted: restart its iterator for the next epoch.
                self._train_data_loader_iter_list[loader_idx] = iter(self.train_data_loader[loader_idx])
                data = next(self._train_data_loader_iter_list[loader_idx])
            data_time = time.perf_counter() - start

            data = comm.unwrap_model(self.model).preprocess_batch(data)
            data[kfg.SS_PROB] = self.ss_prob
            outputs_dict = self.model(data)

            # Group the flat per-choice logits by question, matching the loss's
            # expected (num_questions, SEQ_PER_SAMPLE) layout.
            u_logits = outputs_dict[kfg.U_LOGITS]
            u_logits = u_logits.view(-1, self.cfg.DATALOADER.SEQ_PER_SAMPLE)
            outputs_dict.update({ kfg.U_LOGITS: u_logits })

            losses_dict = {}
            for loss in self.losses:
                loss_dict = loss(outputs_dict)
                losses_dict.update(loss_dict)
            losses = sum(losses_dict.values())

            self.optimizer.zero_grad()
            losses.backward()
            self._write_metrics(losses_dict, data_time)
            self.optimizer.step()