# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li, Jianjie Luo
@contact: yehaoli.sysu@gmail.com, jianjieluo.sysu@gmail.com
"""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from xmodaler.config import configurable
from xmodaler.config import kfg
from .decode_strategy import DecodeStrategy
from .build import DECODE_STRATEGY_REGISTRY
@DECODE_STRATEGY_REGISTRY.register()
class GreedyDecoder(DecodeStrategy):
    """Decoding strategy that emits one token per sequence per time step.

    At every step the next token is either the arg-max of the predicted
    distribution or, when the batch requests it, a multinomial sample.
    """

    def _forward(self, batched_inputs, model):
        """Generate token ids and their log-probabilities step by step.

        Args:
            batched_inputs: dict that must contain ``kfg.ATT_FEATS`` (a tensor
                whose first dimension is the batch size) and ``kfg.IDS``.
                If ``kfg.DECODE_BY_SAMPLE`` is present and truthy, tokens are
                sampled instead of chosen greedily. The dict is updated in
                place with intermediate model outputs.
            model: object exposing ``get_extended_attention_mask``,
                ``visual_embed``, ``encoder``, ``decoder`` (with
                ``preprocess``), ``token_embed`` and ``predictor``; each is
                expected to return a dict that is merged into the inputs.

        Returns:
            The working input dict, extended with ``kfg.IDS``,
            ``kfg.G_SENTS_IDS`` (``batch_size x max_seq_len`` long tensor,
            padded with ``eos_token_id``) and ``kfg.G_LOGP``
            (``batch_size x max_seq_len`` log-probabilities of the selected
            tokens; columns after an early stop stay zero).
        """
        att_feats = batched_inputs[kfg.ATT_FEATS]
        batch_size = att_feats.size(0)
        # Allocate every working tensor on the same device as the input
        # features instead of unconditionally calling ``.cuda()``, so decoding
        # also works on CPU and on a non-current GPU.
        device = att_feats.device
        is_sample = batched_inputs.get(kfg.DECODE_BY_SAMPLE, False)

        # NOTE: ``inputs`` aliases ``batched_inputs``; the updates below
        # mutate the caller's dict.
        inputs = batched_inputs
        masks = model.get_extended_attention_mask(batched_inputs)
        inputs.update(masks)

        ve_out = model.visual_embed(batched_inputs)
        inputs.update(ve_out)

        # Run once before the token loop (mode 'v'; presumably the visual
        # branch of the encoder -- confirm against the model implementation).
        encoder_out_v = model.encoder(inputs, mode='v')
        inputs.update(encoder_out_v)
        inputs = model.decoder.preprocess(inputs)

        # Pre-fill with EOS so that sequences stopping early are already padded.
        sents = torch.full(
            (batch_size, self.max_seq_len),
            self.eos_token_id,
            dtype=torch.long,
            device=device,
        )
        logprobs = torch.zeros(batch_size, self.max_seq_len, device=device)
        wt = torch.full(
            (batch_size,), self.bos_token_id, dtype=torch.long, device=device
        )
        # True for every sequence that has not produced EOS yet.
        unfinished = torch.ones_like(wt, dtype=torch.bool)

        for t in range(self.max_seq_len):
            inputs.update({kfg.G_TOKENS_IDS: wt, kfg.TIME_STEP: t})

            te_out = model.token_embed(inputs)
            inputs.update(te_out)

            encoder_out_t = model.encoder(inputs, mode='t')
            inputs.update(encoder_out_t)

            decoder_out = model.decoder(inputs)
            inputs.update(decoder_out)

            logit = model.predictor(inputs)[kfg.G_LOGITS].view(batch_size, -1)
            logprobs_t = F.log_softmax(logit, dim=-1)

            if is_sample:
                probs_t = torch.exp(logprobs_t)
                wt = torch.multinomial(probs_t, 1)
                logP_t = logprobs_t.gather(1, wt)
            else:
                logP_t, wt = torch.max(logprobs_t, 1)
            wt = wt.view(-1).long()

            # A sequence stays unfinished only while it keeps avoiding EOS;
            # once finished, its next-step input token is forced to EOS.
            unfinished = unfinished & (wt != self.eos_token_id)
            wt = wt.masked_fill(~unfinished, self.eos_token_id)

            sents[:, t] = wt
            # The log-probability of the selected token is stored for every
            # row, including rows that had already finished.
            logprobs[:, t] = logP_t.view(-1)

            if not unfinished.any():
                break

        ret = inputs
        ret.update({
            kfg.IDS: batched_inputs[kfg.IDS],
            kfg.G_SENTS_IDS: sents,
            kfg.G_LOGP: logprobs
        })
        return ret