Source code for xmodaler.modeling.decoder.lowrank_bilinear_decoder

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo, Jingwen Chen
@contact: jianjieluo.sysu@gmail.com, chenjingwen.sysu@gmail.com
"""
import torch
from torch import nn

from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from xmodaler.modeling.layers import LowRankBilinearAttention
from .build import DECODER_REGISTRY
from .decoder import Decoder

# Public API: only the decoder class is exported from this module.
__all__ = ["XLANDecoder"]

@DECODER_REGISTRY.register()
class XLANDecoder(Decoder):
    """Decoder built from an LSTM cell and a low-rank bilinear attention block.

    ``preprocess`` runs once per batch to precompute the attention keys/values
    and the initial recurrent states; ``forward`` then performs one decoding
    step and returns the updated hidden/cell state lists.
    """

    @configurable
    def __init__(
        self,
        *,
        hidden_size: int,
        ctx_drop: float,
        bilinear_dim: int,
        att_heads: int,
        att_mid_dim: list,
        att_mid_drop: float,
        att_embed_dropout: float,
        layer_num: int,
        act_type: str,
        elu_alpha: float,
    ):
        """Build the LSTM cell, the attention block and the output projection.

        Args:
            hidden_size: Hidden size of the LSTM cell; also the width of the
                gated output produced by ``att2ctx``.
            ctx_drop: Dropout probability applied in ``forward`` to the
                previous step's second hidden state.
            bilinear_dim: ``embed_dim`` of the attention block; also part of
                the input width of the LSTM cell and of ``att2ctx``.
            att_heads: Forwarded to ``LowRankBilinearAttention``.
            att_mid_dim: Forwarded to ``LowRankBilinearAttention``. The config
                default registered in ``add_config`` is a list
                (``[128, 64, 128]``), hence the ``list`` annotation.
            att_mid_drop: Forwarded to ``LowRankBilinearAttention``.
            att_embed_dropout: Forwarded to ``LowRankBilinearAttention`` as
                ``dropout``.
            layer_num: Forwarded to ``LowRankBilinearAttention``.
            act_type: Forwarded to ``LowRankBilinearAttention``.
            elu_alpha: Forwarded to ``LowRankBilinearAttention``.
        """
        super().__init__()

        # forward() carries two hidden states and two cell states.
        # NOTE(review): presumably read by Decoder.init_states -- confirm.
        self.num_layers = 2
        self.hidden_size = hidden_size

        # First LSTM layer.
        # NOTE(review): forward() feeds this cell the token embedding
        # concatenated with (global feature + previous output); the widths of
        # those inputs must add up to this value -- confirm against the
        # embedding/encoder configuration.
        rnn_input_size = hidden_size + bilinear_dim
        self.att_lstm = nn.LSTMCell(rnn_input_size, hidden_size)
        self.ctx_drop = nn.Dropout(ctx_drop)

        # Low-rank bilinear attention block used at every decoding step.
        self.attention = LowRankBilinearAttention(
            embed_dim=bilinear_dim,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop,
            dropout=att_embed_dropout,
            layer_num=layer_num,
            act_type=act_type,
            elu_alpha=elu_alpha,
        )

        # Projects [attended feature, LSTM hidden state] to 2 * hidden_size;
        # the GLU gates that down to hidden_size.
        self.att2ctx = nn.Sequential(
            nn.Linear(bilinear_dim + hidden_size, 2 * hidden_size),
            nn.GLU(),
        )

    @classmethod
    def from_config(cls, cfg):
        """Map config entries to the keyword arguments of ``__init__``."""
        bilinear = cfg.MODEL.BILINEAR
        return {
            "hidden_size": cfg.MODEL.DECODER_DIM,
            "ctx_drop": cfg.MODEL.PRED_DROPOUT,
            "bilinear_dim": bilinear.DIM,
            "att_heads": bilinear.HEAD,
            "att_mid_dim": bilinear.DECODE.ATT_MID_DIM,
            "att_mid_drop": bilinear.DECODE.ATT_MID_DROPOUT,
            "att_embed_dropout": bilinear.DECODE.DROPOUT,
            "layer_num": bilinear.DECODE.LAYERS,
            "act_type": bilinear.ACT,
            "elu_alpha": bilinear.ELU_ALPHA,
        }

    @classmethod
    def add_config(cls, cfg):
        """Register the ``MODEL.BILINEAR`` config subtree with its defaults.

        The ``ENCODE`` entries and ``BIFEAT_EMB_ACT`` are registered here but
        are not read by this class's ``from_config``.
        """
        # Settings shared by the whole bilinear subtree.
        cfg.MODEL.BILINEAR = CN()
        cfg.MODEL.BILINEAR.DIM = 1024
        cfg.MODEL.BILINEAR.HEAD = 8
        cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT = "relu"
        cfg.MODEL.BILINEAR.ACT = "celu"
        cfg.MODEL.BILINEAR.ELU_ALPHA = 1.3

        # Decoder-side attention settings (consumed by from_config).
        cfg.MODEL.BILINEAR.DECODE = CN()
        cfg.MODEL.BILINEAR.DECODE.ATT_MID_DIM = [128, 64, 128]
        cfg.MODEL.BILINEAR.DECODE.ATT_MID_DROPOUT = 0.1
        cfg.MODEL.BILINEAR.DECODE.DROPOUT = 0.5
        cfg.MODEL.BILINEAR.DECODE.LAYERS = 1

        # Encoder-side attention settings.
        cfg.MODEL.BILINEAR.ENCODE = CN()
        cfg.MODEL.BILINEAR.ENCODE.ATT_MID_DIM = [128, 64, 128]
        cfg.MODEL.BILINEAR.ENCODE.ATT_MID_DROPOUT = 0.1
        cfg.MODEL.BILINEAR.ENCODE.DROPOUT = 0.5
        cfg.MODEL.BILINEAR.ENCODE.BIFEAT_EMB_DROPOUT = 0.3
        cfg.MODEL.BILINEAR.ENCODE.LAYERS = 4

    def preprocess(self, batched_inputs):
        """Precompute attention inputs and initial states before decoding.

        Adds to ``batched_inputs`` (in place) whatever ``self.init_states``
        returns for the batch, plus ``kfg.P_ATT_FEATS`` holding the attention
        block's precomputed keys and values concatenated on the last dim.

        Returns:
            The same ``batched_inputs`` object, updated.
        """
        att_feats = batched_inputs[kfg.ATT_FEATS]

        # The attention features serve as both key and value source.
        keys, value2s = self.attention.precompute(att_feats, att_feats)
        p_att_feats = torch.cat([keys, value2s], dim=-1)

        # Batch size is taken from the first dim of the attention features.
        init_states = self.init_states(att_feats.shape[0])
        batched_inputs.update(init_states)
        batched_inputs.update({kfg.P_ATT_FEATS: p_att_feats})
        return batched_inputs

    @staticmethod
    def _pool_att_feats(att_feats, att_masks):
        """Average ``att_feats`` over dim 1, weighting by ``att_masks`` if given."""
        if att_masks is None:
            return torch.mean(att_feats, 1)
        weights = att_masks.unsqueeze(-1)
        # NOTE(review): a sample whose mask sums to zero divides by zero here.
        return torch.sum(att_feats * weights, 1) / torch.sum(weights, 1)

    def forward(self, batched_inputs):
        """Run one decoding step.

        Reads the current token embedding, attention features/masks, the
        precomputed attention inputs, the global feature and the previous
        hidden/cell state lists from ``batched_inputs``.

        Returns:
            Dict with ``kfg.G_HIDDEN_STATES`` set to
            ``[lstm_hidden, gated_output]`` and ``kfg.G_CELL_STATES`` set to
            ``[lstm_cell, previous second cell state]``.
        """
        token_embed = batched_inputs[kfg.G_TOKEN_EMBED]
        att_feats = batched_inputs[kfg.ATT_FEATS]
        att_masks = batched_inputs[kfg.ATT_MASKS]
        p_att_feats = batched_inputs[kfg.P_ATT_FEATS]
        global_feat = batched_inputs[kfg.GLOBAL_FEATS]
        prev_hidden = batched_inputs[kfg.G_HIDDEN_STATES]  # list of tensors
        prev_cell = batched_inputs[kfg.G_CELL_STATES]

        # A trailing dim of size 1 marks an empty global feature; fall back to
        # the (masked) mean of the attention features.
        if global_feat.shape[-1] == 1:
            global_feat = self._pool_att_feats(att_feats, att_masks)

        # The LSTM cell sees the token embedding plus the global feature
        # combined with the dropped-out output of the previous step.
        lstm_input = torch.cat(
            [token_embed, global_feat + self.ctx_drop(prev_hidden[1])], 1
        )
        h_att, c_att = self.att_lstm(lstm_input, (prev_hidden[0], prev_cell[0]))

        # Attend over the features using the new LSTM hidden state as query.
        att, _ = self.attention(
            h_att, att_feats, att_masks, p_att_feats, precompute=True
        )
        output = self.att2ctx(torch.cat([att, h_att], 1))

        # The second cell state is passed through unchanged.
        return {
            kfg.G_HIDDEN_STATES: [h_att, output],
            kfg.G_CELL_STATES: [c_att, prev_cell[1]],
        }