Source code for xmodaler.modeling.meta_arch.transformer_enc_dec

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F

from xmodaler.config import configurable
from xmodaler.config import kfg
from xmodaler.functional import pad_tensor, dict_to_cuda
from ..predictor import build_v_predictor
from .base_enc_dec import BaseEncoderDecoder
from .build import META_ARCH_REGISTRY

__all__ = ["TransformerEncoderDecoder"]

[docs]@META_ARCH_REGISTRY.register() class TransformerEncoderDecoder(BaseEncoderDecoder):
[docs] @configurable def __init__( self, *, vocab_size, max_seq_len, token_embed, visual_embed, encoder, decoder, predictor, greedy_decoder, beam_searcher, v_predictor, ): super().__init__( vocab_size=vocab_size, max_seq_len=max_seq_len, token_embed=token_embed, visual_embed=visual_embed, encoder=encoder, decoder=decoder, predictor=predictor, greedy_decoder=greedy_decoder, beam_searcher=beam_searcher ) self.v_predictor = v_predictor
[docs] @classmethod def from_config(cls, cfg): ret = super().from_config(cfg) if cfg.MODEL.BERT.V_TARGET_SIZE > 0: v_predictor = build_v_predictor(cfg) else: v_predictor = None ret.update({ "v_predictor": v_predictor }) return ret
[docs] def get_extended_attention_mask(self, batched_inputs): if kfg.TOKENS_MASKS not in batched_inputs: batched_inputs[kfg.TOKENS_MASKS] = torch.ones((batched_inputs[kfg.ATT_MASKS].size(0), self.max_seq_len)).cuda() tmasks = batched_inputs[kfg.TOKENS_MASKS] seq_length = tmasks.size(-1) tmasks = tmasks.to(dtype=next(self.parameters()).dtype) ext_u_tmasks = tmasks.unsqueeze(1).unsqueeze(2) ext_u_tmasks = (1.0 - ext_u_tmasks) * -10000.0 ext_g_tmasks = torch.tril(torch.ones( (seq_length, seq_length), dtype=tmasks.dtype, device=tmasks.device)) ext_g_tmasks = ext_g_tmasks.unsqueeze(0).expand( (tmasks.size(0), seq_length, seq_length)) ext_g_tmasks = ext_g_tmasks * tmasks.unsqueeze(1) ext_g_tmasks = ext_g_tmasks.to(dtype=next(self.parameters()).dtype) ext_g_tmasks = ext_g_tmasks.unsqueeze(1) ext_g_tmasks = (1.0 - ext_g_tmasks) * -10000.0 vmasks = batched_inputs[kfg.ATT_MASKS] vmasks = vmasks.to(dtype=next(self.parameters()).dtype) vmasks = vmasks.unsqueeze(1).unsqueeze(2) ext_vmasks = (1.0 - vmasks) * -10000.0 return { kfg.TOKENS_MASKS: tmasks, kfg.EXT_U_TOKENS_MASKS: ext_u_tmasks, kfg.EXT_G_TOKENS_MASKS: ext_g_tmasks, kfg.ATT_MASKS: vmasks, kfg.EXT_ATT_MASKS: ext_vmasks }
def _forward(self, batched_inputs): inputs = batched_inputs masks = self.get_extended_attention_mask(batched_inputs) inputs.update(masks) ve_out = self.visual_embed(batched_inputs) inputs.update(ve_out) if self.encoder is not None: encoder_out_v = self.encoder(inputs, mode='v') inputs.update(encoder_out_v) if self.decoder is not None: inputs = self.decoder.preprocess(inputs) te_out = self.token_embed(batched_inputs) inputs.update(te_out) if self.encoder is not None: encoder_out_t = self.encoder(inputs, mode='t') inputs.update(encoder_out_t) if self.decoder is not None: decoder_out = self.decoder(inputs) inputs.update(decoder_out) if self.predictor is not None: tlogits = self.predictor(inputs) inputs.update(tlogits) if self.v_predictor is not None: vlogits = self.v_predictor(inputs) inputs.update(vlogits) return inputs