Source code for xmodaler.modeling.meta_arch.rnn_att_enc_dec

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F

from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from .base_enc_dec import BaseEncoderDecoder
from .build import META_ARCH_REGISTRY

__all__ = ["RnnAttEncoderDecoder"]

[docs]@META_ARCH_REGISTRY.register() class RnnAttEncoderDecoder(BaseEncoderDecoder):
[docs] def get_extended_attention_mask(self, batched_inputs): att_masks = batched_inputs[kfg.ATT_MASKS] if att_masks is not None: att_masks = att_masks.to(dtype=next(self.parameters()).dtype) ext_att_masks = (1.0 - att_masks) * -10000.0 else: ext_att_masks = None return { kfg.ATT_MASKS: att_masks, kfg.EXT_ATT_MASKS: ext_att_masks }
def _forward(self, batched_inputs): inputs = batched_inputs masks = self.get_extended_attention_mask(batched_inputs) inputs.update(masks) ve_out = self.visual_embed(batched_inputs) inputs.update(ve_out) if self.encoder is not None: encoder_out_v = self.encoder(inputs, mode='v') inputs.update(encoder_out_v) if self.decoder is not None: inputs = self.decoder.preprocess(inputs) tokens_ids = batched_inputs[kfg.G_TOKENS_IDS] batch_size, seq_len = tokens_ids.shape ss_prob = batched_inputs[kfg.SS_PROB] outputs = Variable(torch.zeros(batch_size, seq_len, self.vocab_size).cuda()) for t in range(seq_len): if t >= 1 and tokens_ids[:, t].max() == 0: break if self.training and t >= 1 and ss_prob > 0: prob = torch.empty(batch_size).cuda().uniform_(0, 1) mask = prob < ss_prob if mask.sum() == 0: wt = tokens_ids[:, t].clone() else: ind = mask.nonzero().view(-1) wt = tokens_ids[:, t].data.clone() prob_prev = torch.exp(outputs[:, t-1].detach()) wt.index_copy_(0, ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, ind)) else: wt = tokens_ids[:, t].clone() inputs.update({ kfg.G_TOKENS_IDS: wt }) te_out = self.token_embed(inputs) inputs.update(te_out) if self.encoder is not None: encoder_out_t = self.encoder(inputs, mode='t') inputs.update(encoder_out_t) if self.decoder is not None: decoder_out = self.decoder(inputs) inputs.update(decoder_out) if self.predictor is not None: logit = self.predictor(inputs)[kfg.G_LOGITS] outputs[:, t] = logit inputs.update({kfg.G_LOGITS: outputs}) return inputs