# Source code for xmodaler.modeling.decoder.meshed_decoder

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo
@contact: jianjieluo.sysu@gmail.com
"""
import torch
from torch import nn

from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from ..layers.multihead_attention import MultiHeadAttention
from ..layers.positionwise_feedforward import PositionWiseFeedForward

from .decoder import Decoder
from .build import DECODER_REGISTRY

import numpy as np

__all__ = ["MeshedDecoder"]

class MeshedDecoderLayer(nn.Module):
    """Decoder layer that cross-attends to every encoder layer's output.

    The layer applies self-attention to its input, then cross-attends to each
    encoder layer's output using one shared attention module. Every
    cross-attention result is scaled element-wise by a sigmoid gate computed
    from the concatenated self-attention and cross-attention outputs (one
    linear gate per encoder layer). The gated results are summed, divided by
    the square root of the number of gates, and passed through a position-wise
    feed-forward block.
    """

    def __init__(
        self,
        *,
        d_model=512,
        num_head=8,
        d_ff=2048,
        dropout=.1,
        enc_layer_num=3,
    ):
        """Build the attention, feed-forward and gating sub-modules.

        Args:
            d_model: Model width; also the output width of each gate.
            num_head: Number of attention heads; the per-head key/value size
                is ``d_model // num_head``.
            d_ff: Hidden width handed to the feed-forward block.
            dropout: Dropout value handed to the attention and feed-forward
                blocks.
            enc_layer_num: Number of encoder layers, i.e. number of gates.
        """
        super().__init__()

        head_dim = d_model // num_head
        attention_kwargs = dict(
            d_model=d_model,
            d_k=head_dim,
            d_v=head_dim,
            num_head=num_head,
            dropout=dropout,
        )

        # Sub-modules are created in a fixed order (self-attention,
        # cross-attention, feed-forward, gates) so that random initialisation
        # consumes the RNG in a stable sequence.
        self.self_att = MultiHeadAttention(**attention_kwargs)
        self.enc_att = MultiHeadAttention(**attention_kwargs)
        self.pwff = PositionWiseFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        # One gate per encoder layer; each maps the concatenation of the
        # self-attention and cross-attention outputs back to d_model.
        self.fc_alpha = nn.ModuleList(
            [nn.Linear(2 * d_model, d_model) for _ in range(enc_layer_num)]
        )

        # Re-initialise the gates only after all of them exist: Xavier-uniform
        # weights and zero biases.
        for gate in self.fc_alpha:
            nn.init.xavier_uniform_(gate.weight)
            nn.init.constant_(gate.bias, 0)

    def forward(self, input, enc_output, mask_self_att, mask_enc_att, history_states=None):
        """Run self-attention, gated multi-level cross-attention and the FFN.

        Args:
            input: Decoder-side features, used as query, key and value of the
                self-attention.
            enc_output: Encoder features; ``enc_output[:, i]`` is taken as the
                output of the i-th encoder layer.
            mask_self_att: Mask forwarded to the self-attention module.
            mask_enc_att: Mask forwarded to the cross-attention module.
            history_states: Optional value forwarded to the self-attention
                module as ``history_states``.

        Returns:
            The output of the position-wise feed-forward block.
        """
        queries = self.self_att(input, input, input, mask_self_att, history_states=history_states)

        # Accumulate the gated cross-attention over all encoder layers.
        fused = 0
        for layer_idx, gate in enumerate(self.fc_alpha):
            memory = enc_output[:, layer_idx]
            attended = self.enc_att(
                queries, keys=memory, values=memory, attention_mask=mask_enc_att
            )
            alpha = torch.sigmoid(gate(torch.cat([queries, attended], -1)))
            fused = fused + attended * alpha

        # Scale the sum by 1 / sqrt(number of encoder layers).
        fused = fused / np.sqrt(len(self.fc_alpha))

        return self.pwff(fused)

@DECODER_REGISTRY.register()
class MeshedDecoder(Decoder):
    """Stack of ``MeshedDecoderLayer`` modules registered in ``DECODER_REGISTRY``.

    Each layer receives the token features produced by the previous layer
    together with the visual features, and the output of every layer is
    collected and returned.
    """

    @configurable
    def __init__(
        self,
        *,
        d_model: int,
        num_layer: int,
        num_att_head: int,
        d_ff: int,
        dropout: float,
        padding_idx: int,  # -1
        enc_layer_num: int
    ):
        """Store the hyper-parameters and build the layer stack.

        Args:
            d_model: Model width passed to every layer.
            num_layer: Number of ``MeshedDecoderLayer`` modules to stack.
            num_att_head: Number of attention heads per layer.
            d_ff: Feed-forward hidden width per layer.
            dropout: Dropout value passed to every layer.
            padding_idx: Stored on the instance; not otherwise used in this
                class.
            enc_layer_num: Number of encoder layers each decoder layer
                attends to.
        """
        super(MeshedDecoder, self).__init__()
        self.num_layers = num_layer
        self.d_model = d_model
        self.num_att_head = num_att_head
        self.d_ff = d_ff
        self.dropout = dropout
        self.padding_idx = padding_idx

        self.layers = nn.ModuleList([
            MeshedDecoderLayer(
                d_model=self.d_model,
                num_head=self.num_att_head,
                d_ff=self.d_ff,
                dropout=self.dropout,
                enc_layer_num=enc_layer_num
            )
            for _ in range(self.num_layers)
        ])

    @classmethod
    def from_config(cls, cfg):
        """Map config entries to the keyword arguments of ``__init__``."""
        return {
            "d_model": cfg.MODEL.MESHEDMEORY.DECODER.DIM_MODEL,
            "num_layer": cfg.MODEL.MESHEDMEORY.DECODER.NUM_LAYER,
            "num_att_head": cfg.MODEL.MESHEDMEORY.DECODER.NUM_ATT_HEAD,
            "d_ff": cfg.MODEL.MESHEDMEORY.DECODER.DIM_FEEDFORWARD,
            "dropout": cfg.MODEL.MESHEDMEORY.DECODER.DROPOUT,
            "padding_idx": -1,  # default
            "enc_layer_num": cfg.MODEL.MESHEDMEORY.ENCODER.NUM_LAYER
        }

    @classmethod
    def add_config(cls, cfg):
        """Register the decoder's default config entries on ``cfg``."""
        # The parent node is created only if it does not exist yet; the
        # DECODER node is always (re)created.
        if not hasattr(cfg.MODEL, "MESHEDMEORY"):
            cfg.MODEL.MESHEDMEORY = CN()

        cfg.MODEL.MESHEDMEORY.DECODER = CN()
        cfg.MODEL.MESHEDMEORY.DECODER.DIM_MODEL = 512
        cfg.MODEL.MESHEDMEORY.DECODER.NUM_LAYER = 3
        cfg.MODEL.MESHEDMEORY.DECODER.DROPOUT = 0.1
        cfg.MODEL.MESHEDMEORY.DECODER.NUM_ATT_HEAD = 8
        cfg.MODEL.MESHEDMEORY.DECODER.DIM_FEEDFORWARD = 2048

    def forward(self, batched_inputs):
        """Run the token features through every decoder layer.

        Args:
            batched_inputs: Mapping that must contain ``kfg.ATT_FEATS``,
                ``kfg.ATT_MASKS``, ``kfg.G_TOKEN_EMBED`` and
                ``kfg.EXT_G_TOKENS_MASKS``. ``kfg.TIME_STEP`` and
                ``kfg.HISTORY_STATES`` are optional. When a time step is given
                without history states, a fresh history list is created and
                stored back into ``batched_inputs``.

        Returns:
            A dict mapping ``kfg.G_HIDDEN_STATES`` to the list of per-layer
            outputs, in layer order.
        """
        vfeats = batched_inputs[kfg.ATT_FEATS]
        vmasks = batched_inputs[kfg.ATT_MASKS]
        history_states = batched_inputs.get(kfg.HISTORY_STATES, None)

        g_tfeats = batched_inputs[kfg.G_TOKEN_EMBED]
        ext_g_tmasks = batched_inputs[kfg.EXT_G_TOKENS_MASKS]
        # Boolean mask: True where the incoming mask equals -10000.0.
        ext_g_tmasks = (ext_g_tmasks == -10000.0)  # FIXME
        if g_tfeats.dim() == 2:
            # Insert a dimension at index 1 so the features are 3-D
            # (presumably a single-step sequence axis -- confirm with caller).
            g_tfeats = g_tfeats.unsqueeze(1)

        if kfg.TIME_STEP in batched_inputs:
            time_step = batched_inputs[kfg.TIME_STEP]
            # Keep only the row of the current step and columns 0..time_step.
            ext_g_tmasks = ext_g_tmasks[:, :, time_step:time_step + 1, 0:time_step + 1]
            if kfg.HISTORY_STATES not in batched_inputs:
                # Start every layer's history as a tensor with zero entries
                # along dim 1, matching g_tfeats in dtype and device.
                shape = list(g_tfeats.size())
                shape[1] = 0
                history_states = [g_tfeats.new_empty(shape)] * self.num_layers
                batched_inputs[kfg.HISTORY_STATES] = history_states
        else:
            history_states = [None] * self.num_layers

        g_tfeats_arr = []
        for i, layer_module in enumerate(self.layers):
            if history_states[i] is not None:
                # Append the current layer input to this layer's history. The
                # list is updated in place, so the list stored in
                # batched_inputs sees the new entry as well.
                history_states[i] = torch.cat([history_states[i], g_tfeats], dim=1)

            g_tfeats = layer_module(g_tfeats, vfeats, ext_g_tmasks, vmasks, history_states[i])
            g_tfeats_arr.append(g_tfeats)

        return {kfg.G_HIDDEN_STATES: g_tfeats_arr}