# Source code for xmodaler.modeling.encoder.single_stream_bert_encoder

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo
@contact: jianjieluo.sysu@gmail.com
"""
import torch
from torch import nn

from xmodaler.config import configurable
from xmodaler.config import kfg
from ..layers.bert import BertLayer
from .build import ENCODER_REGISTRY

__all__ = ["SingleStreamBertEncoder"]

@ENCODER_REGISTRY.register()
class SingleStreamBertEncoder(nn.Module):
    """Encoder that passes one joint sequence through a stack of ``BertLayer``.

    The entries ``kfg.U_TOKEN_EMBED`` and ``kfg.ATT_FEATS`` of the input
    dictionary are concatenated along dim 1 (token embeddings first), fed
    through every layer in ``self.layers`` together with the concatenated
    extended masks, and finally split back into the two original segments.
    """

    @configurable
    def __init__(self, *, num_hidden_layers: int, bert_layers):
        """Store the layer stack.

        Args:
            num_hidden_layers: Number of layers; only stored on the instance.
            bert_layers: Iterable container of layer modules that ``forward``
                applies in order (``from_config`` passes an ``nn.ModuleList``).
        """
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.layers = bert_layers

    @classmethod
    def from_config(cls, cfg):
        """Build the constructor keyword arguments from ``cfg``.

        Creates ``cfg.MODEL.BERT.NUM_HIDDEN_LAYERS`` independent
        ``BertLayer(cfg)`` instances and wraps them in an ``nn.ModuleList``.

        Returns:
            dict with the keys ``"num_hidden_layers"`` and ``"bert_layers"``.
        """
        stack = nn.ModuleList()
        for _ in range(cfg.MODEL.BERT.NUM_HIDDEN_LAYERS):
            stack.append(BertLayer(cfg))

        return {
            "num_hidden_layers": cfg.MODEL.BERT.NUM_HIDDEN_LAYERS,
            "bert_layers": stack,
        }

    @classmethod
    def add_config(cls, cfg):
        """Hook for registering config options; this encoder adds none."""

    def forward(self, batched_inputs, mode=None):
        """Encode the joint token/attention-feature sequence.

        Args:
            batched_inputs: Mapping that provides ``kfg.ATT_FEATS``,
                ``kfg.EXT_ATT_MASKS``, ``kfg.U_TOKEN_EMBED`` and
                ``kfg.EXT_U_TOKENS_MASKS``.
            mode: Stage selector; the original inline note lists
                ``'v'``, ``'t'`` and ``'vt'``. The encoder only does work
                when ``mode`` is ``None`` or contains ``'t'``.

        Returns:
            dict: empty when the mode check fails; otherwise it maps
            ``kfg.ATT_FEATS`` and ``kfg.U_HIDDEN_STATES`` to the encoded
            attention-feature and token segments respectively.
        """
        # Any explicit mode without 't' is a no-op for this encoder.
        if mode is not None and 't' not in mode:
            return {}

        att_feats = batched_inputs[kfg.ATT_FEATS]
        att_ext_masks = batched_inputs[kfg.EXT_ATT_MASKS]
        token_embed = batched_inputs[kfg.U_TOKEN_EMBED]
        token_ext_masks = batched_inputs[kfg.EXT_U_TOKENS_MASKS]

        # Remember where the token segment ends so the joint sequence can be
        # split again after encoding.
        num_tokens = token_embed.size(1)

        # Features are joined along dim 1, the masks along their last dim;
        # tokens come first in both.
        hidden = torch.cat([token_embed, att_feats], dim=1)
        joint_mask = torch.cat([token_ext_masks, att_ext_masks], dim=-1)

        for bert_layer in self.layers:
            # Each layer returns a pair; only the first element is carried on.
            hidden, _ = bert_layer(hidden, joint_mask)

        token_hidden = hidden[:, :num_tokens]
        att_hidden = hidden[:, num_tokens:]

        return {
            kfg.ATT_FEATS: att_hidden,
            kfg.U_HIDDEN_STATES: token_hidden,
        }