Source code for xmodaler.modeling.decoder.decouple_bert_decoder

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li, Jianjie Luo
@contact: yehaoli.sysu@gmail.com, jianjieluo.sysu@gmail.com
"""
import random
import torch
from torch import nn

from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from .decoder import Decoder
from ..layers.bert import BertUnderstandingLayer, BertGenerationLayer
from .build import DECODER_REGISTRY

__all__ = ["DecoupleBertDecoder"]

[docs]@DECODER_REGISTRY.register() class DecoupleBertDecoder(Decoder):
[docs] @configurable def __init__( self, *, num_understanding_layers: int, num_generation_layers: int, u_layer_drop: float, g_layer_drop: float, bert_understanding_layers, bert_generation_layers ): super(DecoupleBertDecoder, self).__init__() self.num_understanding_layers = num_understanding_layers self.num_generation_layers = num_generation_layers self.u_layer_drop = u_layer_drop self.g_layer_drop = g_layer_drop if self.num_understanding_layers > 0: self.u_layers = bert_understanding_layers if self.num_generation_layers > 0: self.g_layers = bert_generation_layers
[docs] @classmethod def from_config(cls, cfg): bert_understanding_layers = nn.ModuleList( [BertUnderstandingLayer(cfg) for _ in range(cfg.MODEL.BERT.NUM_UNDERSTANDING_LAYERS)] ) bert_generation_layers = nn.ModuleList( [BertGenerationLayer(cfg) for _ in range(cfg.MODEL.BERT.NUM_GENERATION_LAYERS)] ) return { "num_understanding_layers": cfg.MODEL.BERT.NUM_UNDERSTANDING_LAYERS, "num_generation_layers": cfg.MODEL.BERT.NUM_GENERATION_LAYERS, "u_layer_drop": cfg.MODEL.BERT.U_LAYER_DROP, "g_layer_drop": cfg.MODEL.BERT.G_LAYER_DROP, "bert_understanding_layers": bert_understanding_layers, "bert_generation_layers": bert_generation_layers, }
[docs] @classmethod def add_config(cls, cfg): pass
[docs] def forward(self, batched_inputs): ret = {} vfeats = batched_inputs[kfg.ATT_FEATS] ext_vmasks = batched_inputs[kfg.EXT_ATT_MASKS] if isinstance(vfeats, list): vfeats = vfeats[-1] if (kfg.G_TOKEN_EMBED not in batched_inputs) and (self.num_generation_layers > 0): batched_inputs[kfg.G_TOKEN_EMBED] = batched_inputs[kfg.U_TOKEN_EMBED] if (kfg.U_TOKEN_EMBED not in batched_inputs) and (self.num_understanding_layers > 0): batched_inputs[kfg.U_TOKEN_EMBED] = batched_inputs[kfg.G_TOKEN_EMBED] if kfg.U_TOKEN_EMBED in batched_inputs: u_tfeats = batched_inputs[kfg.U_TOKEN_EMBED] ext_u_tmasks = batched_inputs[kfg.EXT_U_TOKENS_MASKS] if isinstance(u_tfeats, list): u_tfeats = u_tfeats[-1] vfeats_arr = [] u_tfeats_arr = [] for layer_module in self.u_layers: dropout_probability = random.uniform(0, 1) if self.training and (dropout_probability < self.u_layer_drop): vfeats_arr.append(vfeats) u_tfeats_arr.append(u_tfeats) else: vfeats, u_tfeats = layer_module(vfeats, ext_vmasks, u_tfeats, ext_u_tmasks) vfeats_arr.append(vfeats) u_tfeats_arr.append(u_tfeats) ret.update({ kfg.U_HIDDEN_STATES: u_tfeats_arr, kfg.ATT_FEATS: vfeats_arr }) if kfg.G_TOKEN_EMBED in batched_inputs: g_tfeats = batched_inputs[kfg.G_TOKEN_EMBED] ext_g_tmasks = batched_inputs[kfg.EXT_G_TOKENS_MASKS] if isinstance(g_tfeats, list): g_tfeats = g_tfeats[-1] if len(g_tfeats.size()) == 2: g_tfeats = g_tfeats.unsqueeze(1) history_states = batched_inputs.get(kfg.HISTORY_STATES, None) if kfg.TIME_STEP in batched_inputs: time_step = batched_inputs[kfg.TIME_STEP] ext_g_tmasks = ext_g_tmasks[:,:, time_step:time_step+1, 0:time_step+1] if kfg.HISTORY_STATES not in batched_inputs: shape = list(g_tfeats.size()) shape[1] = 0 history_states = [g_tfeats.new(torch.Size(shape))] * self.num_generation_layers batched_inputs[kfg.HISTORY_STATES] = history_states else: history_states = [None] * self.num_generation_layers g_tfeats_arr = [] for i, layer_module in enumerate(self.g_layers): if history_states[i] is not None: history_states[i] = torch.cat([history_states[i], g_tfeats], dim=1) g_tfeats = layer_module(g_tfeats, vfeats, ext_g_tmasks, ext_vmasks, history_states[i]) g_tfeats_arr.append(g_tfeats) ret.update({ kfg.G_HIDDEN_STATES: g_tfeats_arr }) return ret