# Source code for xmodaler.engine.retrieval_trainer

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import time
import tqdm
import copy
import numpy as np
import itertools
import torch
from .defaults import DefaultTrainer
from xmodaler.scorer import build_scorer
from xmodaler.config import kfg
from xmodaler.losses import build_rl_losses
import xmodaler.utils.comm as comm
from .build import ENGINE_REGISTRY

__all__ = ['RetrievalTrainer']

@ENGINE_REGISTRY.register()
class RetrievalTrainer(DefaultTrainer):
    """Trainer that adds a retrieval-style evaluation routine.

    Only ``__init__`` (which just forwards ``cfg``) and ``test`` are defined
    here; everything else is inherited from ``DefaultTrainer``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    @staticmethod
    def _build_labels(ids):
        """Build the 0/1 ground-truth matrix passed to the evaluator.

        Each element of ``ids`` is indexed as ``ids[k][0]`` and ``ids[k][1]``,
        where ``ids[k][1]`` is iterable. All ``ids[k][1]`` values are flattened
        (in order) into the rows; the ``ids[k][0]`` values form the columns.
        A cell is 1 when the row value equals the column value, else 0.

        Returns:
            numpy int array of shape ``(num_flattened_second_values, len(ids))``.
        """
        iids = [entry[0] for entry in ids]
        cids = list(itertools.chain.from_iterable(entry[1] for entry in ids))
        # Broadcast a column vector against a row vector to compare every pair.
        labels = np.expand_dims(cids, axis=1) == np.expand_dims(iids, axis=0)
        return labels.astype(int)

    @classmethod
    def test(cls, cfg, model, test_data_loader, evaluator, epoch):
        """Run ``model`` over ``test_data_loader`` and score the retrieval outputs.

        Args:
            cfg: not read by this method; kept for interface compatibility.
            model: model whose output dict entry ``kfg.OUTPUT`` is indexed as
                ``[0]`` and ``[1]``. Each is collected per batch and
                concatenated along dim 0 (assumed to be visual and text
                features respectively, judging by the names; TODO confirm).
            test_data_loader: iterable of batches; each batch is passed through
                ``comm.unwrap_model(model).preprocess_batch`` before the forward
                pass, and its ``kfg.IDS`` entry is accumulated.
            evaluator: object exposing ``eval(vfeats, tfeats, labels)``, or None.
            epoch: not read by this method; kept for interface compatibility.

        Returns:
            Whatever ``evaluator.eval`` returns, or None when ``evaluator`` is
            None.

        The model is put in eval mode for the pass and is always switched back
        to train mode on exit, including when an exception is raised.
        """
        model.eval()
        ids = []
        vfeats = []
        tfeats = []
        try:
            with torch.no_grad():
                for data in tqdm.tqdm(test_data_loader):
                    data = comm.unwrap_model(model).preprocess_batch(data)
                    outputs = model(data)[kfg.OUTPUT]
                    ids += data[kfg.IDS]
                    vfeats.append(outputs[0])
                    tfeats.append(outputs[1])

            # With no evaluator there is nothing to score. Return None
            # explicitly instead of hitting an unassigned result variable.
            if evaluator is None:
                return None

            labels = cls._build_labels(ids)
            vfeats = torch.cat(vfeats, dim=0)
            tfeats = torch.cat(tfeats, dim=0)
            return evaluator.eval(vfeats, tfeats, labels)
        finally:
            # Restore training mode even if inference/evaluation failed, so the
            # caller is not left with the model stuck in eval mode.
            model.train()