# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from torch import nn
from xmodaler.config import configurable
from xmodaler.config import kfg
from ..layers.create_act import get_act_layer
from .build import EMBEDDING_REGISTRY
from .position_embedding import build_position_encoding
# Public names exported by a star-import of this module.
__all__ = ["TokenBaseEmbedding"]
@EMBEDDING_REGISTRY.register()
class TokenBaseEmbedding(nn.Module):
    """Token embedding layer with optional add-on components.

    Token ids are looked up in an ``nn.Embedding`` table. Then, for every
    optional component that was supplied at construction time, the result
    gets (in this order) a position encoding added, a token-type embedding
    added, and an activation, a normalization and a dropout module applied.
    """

    @configurable
    def __init__(
        self,
        *,
        dim: int,
        vocab_size: int,  # include <BOS>/<EOS>
        **kwargs
    ):
        """Build the embedding table and store the optional components.

        Args:
            dim: size of each embedding vector.
            vocab_size: number of rows of the embedding table.
            **kwargs: optional components, each defaulting to ``None``
                (disabled): ``embeddings_act``, ``embeddings_norm``,
                ``embeddings_dropout``, ``embeddings_pos`` and
                ``embeddings_token_type``. Any other key is ignored.
        """
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, dim)
        # A component left as None is simply skipped in ``_forward``.
        self.embeddings_act = kwargs.pop("embeddings_act", None)
        self.embeddings_norm = kwargs.pop("embeddings_norm", None)
        self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
        self.embeddings_pos = kwargs.pop("embeddings_pos", None)
        self.embeddings_token_type = kwargs.pop("embeddings_token_type", None)

    @classmethod
    def from_config(cls, cfg):
        """Translate ``cfg`` into the keyword arguments of ``__init__``.

        Args:
            cfg: config node; ``cfg.MODEL.VOCAB_SIZE`` and the entries of
                ``cfg.MODEL.TOKEN_EMBED`` are read.

        Returns:
            dict with ``dim`` and ``vocab_size`` plus one entry for every
            optional component enabled by the config.

        Raises:
            ValueError: if the configured activation name is not "none" and
                ``get_act_layer`` returns ``None`` for it.
        """
        token_cfg = cfg.MODEL.TOKEN_EMBED
        kwargs = {
            "dim": token_cfg.DIM,
            "vocab_size": cfg.MODEL.VOCAB_SIZE,
        }

        activation_name = token_cfg.ACTIVATION.lower()
        if activation_name != "none":
            activation = get_act_layer(activation_name)
            # Explicit check instead of ``assert`` so that the validation is
            # still performed when Python runs with -O.
            if activation is None:
                raise ValueError(
                    "Unsupported MODEL.TOKEN_EMBED.ACTIVATION: "
                    f"{token_cfg.ACTIVATION!r}"
                )
            act_kwargs = {}
            # Only these two activations are given an ``alpha`` argument.
            if activation_name in {"elu", "celu"}:
                act_kwargs["alpha"] = token_cfg.ELU_ALPHA
            kwargs["embeddings_act"] = activation(**act_kwargs)

        if token_cfg.DROPOUT > 0:
            kwargs["embeddings_dropout"] = nn.Dropout(token_cfg.DROPOUT)

        if token_cfg.USE_NORM:
            kwargs["embeddings_norm"] = nn.LayerNorm(token_cfg.DIM)

        if token_cfg.POSITION.lower() != "none":
            kwargs["embeddings_pos"] = build_position_encoding(
                cfg, token_cfg.DIM, token_cfg.POSITION_MAX_LEN
            )

        if token_cfg.TYPE_VOCAB_SIZE > 0:
            kwargs["embeddings_token_type"] = nn.Embedding(
                token_cfg.TYPE_VOCAB_SIZE, token_cfg.DIM
            )

        return kwargs

    def forward(self, batched_inputs):
        """Embed whichever token-id entries are present in the batch.

        Args:
            batched_inputs: mapping that may contain ``kfg.U_TOKENS_IDS``
                and/or ``kfg.G_TOKENS_IDS``, optionally together with
                ``kfg.U_TOKENS_TYPE``, ``kfg.G_TOKENS_TYPE`` and
                ``kfg.TIME_STEP``.

        Returns:
            dict holding ``kfg.U_TOKEN_EMBED`` and/or ``kfg.G_TOKEN_EMBED``
            for the id entries that were found; empty if neither is present.
        """
        ret = {}

        if kfg.U_TOKENS_IDS in batched_inputs:
            u_tokens_ids = batched_inputs[kfg.U_TOKENS_IDS]
            u_tokens_type = batched_inputs.get(kfg.U_TOKENS_TYPE, None)
            ret[kfg.U_TOKEN_EMBED] = self._forward(
                u_tokens_ids, token_type_ids=u_tokens_type
            )

        if kfg.G_TOKENS_IDS in batched_inputs:
            # Only the G_* branch forwards the time step.
            # NOTE(review): presumably used for step-wise decoding - confirm
            # against the callers.
            time_step = batched_inputs.get(kfg.TIME_STEP, None)
            g_tokens_ids = batched_inputs[kfg.G_TOKENS_IDS]
            g_tokens_type = batched_inputs.get(kfg.G_TOKENS_TYPE, None)
            ret[kfg.G_TOKEN_EMBED] = self._forward(
                g_tokens_ids, token_type_ids=g_tokens_type, time_step=time_step
            )

        return ret

    def _forward(self, input_ids, token_type_ids=None, time_step=None):
        """Embed ``input_ids`` and apply the configured optional components.

        Args:
            input_ids: token ids passed to the embedding table.
            token_type_ids: optional token-type ids; only used when a
                token-type embedding was configured.
            time_step: optional step index. When given, it is passed to the
                position module in place of ``input_ids`` and is used to
                select column ``time_step`` of ``token_type_ids``.

        Returns:
            The resulting embedding tensor.
        """
        embeddings = self.embeddings(input_ids)

        if self.embeddings_pos is not None:
            # ``is None`` (not truthiness) so that a time step of 0 counts.
            pos_inputs = input_ids if time_step is None else time_step
            embeddings = embeddings + self.embeddings_pos(pos_inputs)

        if self.embeddings_token_type is not None and token_type_ids is not None:
            if time_step is not None:
                # assumes token_type_ids has the step index on dim 1 - TODO confirm
                token_type_ids = token_type_ids[:, time_step]
            embeddings = embeddings + self.embeddings_token_type(token_type_ids)

        if self.embeddings_act is not None:
            embeddings = self.embeddings_act(embeddings)

        if self.embeddings_norm is not None:
            embeddings = self.embeddings_norm(embeddings)

        if self.embeddings_dropout is not None:
            embeddings = self.embeddings_dropout(embeddings)

        return embeddings