# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from torch import nn
from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from ..layers.base_attention import BaseAttention
from .decoder import Decoder
from .build import DECODER_REGISTRY
__all__ = ["UpDownDecoder"]
[docs]@DECODER_REGISTRY.register()
class UpDownDecoder(Decoder):
[docs] @configurable
def __init__(
self,
*,
hidden_size: int,
token_embed_dim: int,
visual_embed_dim: int,
att_embed_size: int,
dropout1: float,
dropout2: float,
att_embed_dropout: float
):
super(UpDownDecoder, self).__init__()
self.num_layers = 2
self.hidden_size = hidden_size
in_dim = hidden_size + token_embed_dim + visual_embed_dim
self.lstm1 = nn.LSTMCell(in_dim, hidden_size)
self.dropout1 = nn.Dropout(dropout1) if dropout1 > 0 else None
in_dim = hidden_size + visual_embed_dim
self.lstm2 = nn.LSTMCell(in_dim, hidden_size)
self.dropout2 = nn.Dropout(dropout2) if dropout2 > 0 else None
self.att = BaseAttention(
hidden_size = hidden_size,
att_embed_size = att_embed_size,
att_embed_dropout = att_embed_dropout
)
self.p_att_feats = nn.Linear(visual_embed_dim, att_embed_size)
[docs] @classmethod
def from_config(cls, cfg):
return {
"hidden_size": cfg.MODEL.DECODER_DIM,
"token_embed_dim": cfg.MODEL.TOKEN_EMBED.DIM,
"visual_embed_dim": cfg.MODEL.VISUAL_EMBED.OUT_DIM,
"att_embed_size": cfg.MODEL.UPDOWN.ATT_EMBED_SIZE,
"dropout1": cfg.MODEL.UPDOWN.DROPOUT1,
"dropout2": cfg.MODEL.UPDOWN.DROPOUT2,
"att_embed_dropout": cfg.MODEL.UPDOWN.ATT_EMBED_DROPOUT
}
[docs] @classmethod
def add_config(cls, cfg):
cfg.MODEL.UPDOWN = CN()
cfg.MODEL.UPDOWN.ATT_EMBED_SIZE = 512
cfg.MODEL.UPDOWN.DROPOUT1 = 0.0
cfg.MODEL.UPDOWN.DROPOUT2 = 0.0
cfg.MODEL.UPDOWN.ATT_EMBED_DROPOUT = 0.0
[docs] def preprocess(self, batched_inputs):
att_feats = batched_inputs[kfg.ATT_FEATS]
p_att_feats = self.p_att_feats(att_feats)
init_states = self.init_states(att_feats.shape[0])
batched_inputs.update(init_states)
batched_inputs.update( { kfg.P_ATT_FEATS: p_att_feats } )
return batched_inputs
[docs] def forward(self, batched_inputs):
wt = batched_inputs[kfg.G_TOKEN_EMBED]
att_feats = batched_inputs[kfg.ATT_FEATS]
ext_att_masks = batched_inputs[kfg.EXT_ATT_MASKS]
p_att_feats = batched_inputs[kfg.P_ATT_FEATS]
global_feats = batched_inputs[kfg.GLOBAL_FEATS]
hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
cell_states = batched_inputs[kfg.G_CELL_STATES]
# lstm1
h2_tm1 = hidden_states[-1]
input1 = torch.cat([h2_tm1, global_feats, wt], 1)
if self.dropout1 is not None:
input1 = self.dropout1(input1)
h1_t, c1_t = self.lstm1(input1, (hidden_states[0], cell_states[0]))
att = self.att(h1_t, att_feats, p_att_feats, ext_att_masks)
# lstm2
input2 = torch.cat([att, h1_t], 1)
if self.dropout2 is not None:
input2 = self.dropout2(input2)
h2_t, c2_t = self.lstm2(input2, (hidden_states[1], cell_states[1]))
hidden_states = [h1_t, h2_t]
cell_states = [c1_t, c2_t]
return {
kfg.G_HIDDEN_STATES: hidden_states,
kfg.G_CELL_STATES: cell_states
}