Source code for xmodaler.modeling.encoder.transformer_encoder
# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from torch import nn
from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from ..layers.bert import BertLayer
from .build import ENCODER_REGISTRY
__all__ = ["TransformerEncoder"]
[docs]@ENCODER_REGISTRY.register()
class TransformerEncoder(nn.Module):
[docs] @configurable
def __init__(
self,
*,
num_hidden_layers: int,
bert_layers,
):
super(TransformerEncoder, self).__init__()
self.num_hidden_layers = num_hidden_layers
self.layers = bert_layers
[docs] @classmethod
def from_config(cls, cfg):
bert_layers = nn.ModuleList(
[BertLayer(cfg) for _ in range(cfg.MODEL.BERT.NUM_HIDDEN_LAYERS)]
)
return {
"num_hidden_layers": cfg.MODEL.BERT.NUM_HIDDEN_LAYERS,
"bert_layers": bert_layers,
}
[docs] def forward(self, batched_inputs, mode=None):
ret = {}
if mode == None or mode == 'v':
vfeats = batched_inputs[kfg.ATT_FEATS]
ext_vmasks = batched_inputs[kfg.EXT_ATT_MASKS]
for layer_module in self.layers:
vfeats, _ = layer_module(vfeats, ext_vmasks)
ret.update({ kfg.ATT_FEATS: vfeats })
return ret