# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo, Jingwen Chen
@contact: jianjieluo.sysu@gmail.com, chenjingwen.sysu@gmail.com
"""
import torch
import torch.nn as nn
from xmodaler.modeling.layers import get_act_layer
from .scattention import SCAttention
# Public API of this module. LowRank is intentionally not exported; it is
# used internally via LowRankBilinearLayer.
__all__ = ["LowRankBilinearAttention", "LowRankBilinearLayer"]
class LowRank(nn.Module):
    """Multi-head low-rank bilinear attention unit.

    Query, key and the two value streams (``value1``, ``value2``) are each
    projected by ``Linear -> (activation) -> GroupNorm(num_heads)`` and split
    into ``num_heads`` heads of size ``embed_dim // num_heads``.  The
    element-wise product of projected query and keys is passed, together with
    the projected values, to ``SCAttention``.
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        att_heads: int,
        att_mid_dim: list,
        att_mid_drop: float,
        act_type: str,
        elu_alpha: float
    ):
        super(LowRank, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = att_heads
        self.head_dim = embed_dim // self.num_heads
        # NOTE(review): not read by any method in this class; kept for compatibility.
        self.scaling = self.head_dim ** -0.5

        # Created in the original order (q, k, v1, v2) so parameter initialisation
        # consumes the RNG exactly as before and state_dict keys are unchanged.
        self.in_proj_q = self._build_proj(act_type, elu_alpha)
        self.in_proj_k = self._build_proj(act_type, elu_alpha)
        self.in_proj_v1 = self._build_proj(act_type, elu_alpha)
        self.in_proj_v2 = self._build_proj(act_type, elu_alpha)
        self.attn_net = SCAttention(att_mid_dim, att_mid_drop)
        # forward2 reads these caches; they must exist before the first call.
        self.clear_buffer()

    def _build_proj(self, act_type, elu_alpha):
        """Return ``Linear -> (activation) -> GroupNorm`` mapping to ``embed_dim`` features."""
        # With 'GLU' the Linear emits twice the width; presumably the GLU activation
        # halves it back to embed_dim, which GroupNorm below expects.
        output_dim = 2 * self.embed_dim if act_type == 'GLU' else self.embed_dim
        layers = [nn.Linear(self.embed_dim, output_dim)]
        act_cls = get_act_layer(act_type)
        # Only instantiate when a layer factory exists; otherwise skip the activation.
        act = act_cls(elu_alpha) if act_cls is not None else None
        if act is not None:
            layers.append(act)
        layers.append(nn.GroupNorm(self.num_heads, self.embed_dim))
        return nn.Sequential(*layers)

    def clear_buffer(self):
        """Reset the incremental key/value2 caches used by ``forward2``."""
        self.buffer_keys = None
        self.buffer_value2 = None

    def _project_kv(self, key, value2, batch_size):
        """Project ``key``/``value2`` and split into heads.

        Returns:
            (k, v2), each of shape (batch_size, num_heads, -1, head_dim).
        """
        k = self.in_proj_k(key.reshape(-1, key.size(-1)))
        v2 = self.in_proj_v2(value2.reshape(-1, value2.size(-1)))
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v2 = v2.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return k, v2

    # query -- batch_size * qdim
    # value -- batch_size * att_num * vdim
    def forward(self, query, key, mask, value1, value2, precompute=False):
        """Attend from a single query vector per sample.

        Args:
            query: (batch_size, qdim) query features.
            key, value2: raw features (projected here), or, when ``precompute`` is
                true, tensors already produced by :meth:`precompute`.
            mask: passed unchanged to ``SCAttention``.
            value1: projected and split into heads like ``query``.
            precompute: whether ``key``/``value2`` are already projected.

        Returns:
            Tensor of shape (batch_size, num_heads * head_dim).
        """
        batch_size = query.size(0)
        q = self.in_proj_q(query).view(batch_size, self.num_heads, self.head_dim)
        v1 = self.in_proj_v1(value1).view(batch_size, self.num_heads, self.head_dim)

        if precompute:
            k, v2 = key, value2
        else:
            k, v2 = self._project_kv(key, value2, batch_size)

        # Low-rank bilinear map: broadcast the per-head query over every key.
        attn_map = q.unsqueeze(-2) * k
        attn = self.attn_net(attn_map, mask, v1, v2)
        return attn.view(batch_size, self.num_heads * self.head_dim)

    # query -- batch_size * seq_num * qdim
    # value -- batch_size * att_num * vdim
    def forward2(self, query, key, mask, value1, value2, precompute=False):
        """Attend from a sequence of queries per sample.

        When not precomputing and both buffers are set, the newly projected
        keys/value2s are appended to ``buffer_keys``/``buffer_value2`` along
        dim 2, and attention runs over the full cached sequence.

        Returns:
            Tensor of shape (batch_size, seq_num, num_heads * head_dim).
        """
        batch_size = query.size(0)
        q = self.in_proj_q(query.reshape(-1, query.size(-1)))
        v1 = self.in_proj_v1(value1.reshape(-1, value1.size(-1)))
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v1 = v1.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if precompute:
            k, v2 = key, value2
        else:
            k, v2 = self._project_kv(key, value2, batch_size)
            if self.buffer_keys is not None and self.buffer_value2 is not None:
                self.buffer_keys = torch.cat([self.buffer_keys, k], dim=2)
                self.buffer_value2 = torch.cat([self.buffer_value2, v2], dim=2)
                k = self.buffer_keys
                v2 = self.buffer_value2

        attn_map = q.unsqueeze(-2) * k.unsqueeze(-3)
        # Call the module (not .forward) so registered hooks run.
        attn = self.attn_net(attn_map, mask, v1, v2).transpose(1, 2).contiguous()
        return attn.view(batch_size, -1, self.num_heads * self.head_dim)

    def precompute(self, key, value2):
        """Project ``key``/``value2`` once for reuse with ``precompute=True``.

        Returns:
            (k, v2), each of shape (batch, num_heads, -1, head_dim).
        """
        return self._project_kv(key, value2, value2.size(0))
class LowRankBilinearLayer(nn.Module):
    """A ``LowRank`` attention unit followed by optional dropout."""

    def __init__(
        self,
        *,
        embed_dim: int,
        att_heads: int,
        att_mid_dim: list,
        att_mid_drop: float,
        dropout: float,
        act_type: str,
        elu_alpha: float
    ):
        super(LowRankBilinearLayer, self).__init__()
        self.encoder_attn = LowRank(
            embed_dim=embed_dim,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop,
            act_type=act_type,
            elu_alpha=elu_alpha
        )
        # No Dropout module at all when the rate is not positive.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(
        self,
        x,
        key=None,
        mask=None,
        value1=None,
        value2=None,
        precompute=False
    ):
        """Run the wrapped ``LowRank`` unit with ``x`` as query.

        Any of ``key``, ``value1``, ``value2`` left as None defaults to ``x``.
        ``mask`` and ``precompute`` are forwarded unchanged.

        Returns:
            The attention output, passed through dropout when configured.
        """
        x = self.encoder_attn(
            query=x,
            key=key if key is not None else x,
            mask=mask,
            value1=value1 if value1 is not None else x,
            value2=value2 if value2 is not None else x,
            precompute=precompute
        )
        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def precompute(self, key, value2):
        """Delegate to the wrapped unit's ``precompute``; returns its (k, v2) tuple."""
        return self.encoder_attn.precompute(key, value2)
class LowRankBilinearAttention(nn.Module):
    """Stack of ``LowRankBilinearLayer``s that refines a global feature.

    Each layer attends from the current global feature over ``att_feats``.  The
    initial global feature and every layer's output are concatenated,
    projected back to ``embed_dim`` and layer-normalised.
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        att_heads: int,
        att_mid_dim: list,
        att_mid_drop: float,
        dropout: float,
        layer_num: int,
        act_type: str,
        elu_alpha: float
    ):
        super(LowRankBilinearAttention, self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(layer_num):
            sublayer = LowRankBilinearLayer(
                embed_dim=embed_dim,
                att_heads=att_heads,
                att_mid_dim=att_mid_dim,
                att_mid_drop=att_mid_drop,
                dropout=dropout,
                act_type=act_type,
                elu_alpha=elu_alpha
            )
            self.layers.append(sublayer)
        # Input width: initial gv_feat plus one output per layer, concatenated.
        self.proj = nn.Linear(embed_dim * (layer_num + 1), embed_dim)
        self.layer_norm = torch.nn.LayerNorm(embed_dim)

    def precompute(self, key, value2):
        """Project ``key``/``value2`` once for every layer.

        Returns:
            (keys, value2s): the per-layer projections concatenated along the last dim.
        """
        keys = []
        value2s = []
        for layer in self.layers:
            k, v = layer.precompute(key, value2)
            keys.append(k)
            value2s.append(v)
        return torch.cat(keys, dim=-1), torch.cat(value2s, dim=-1)

    def forward(self, gv_feat, att_feats, att_mask, p_att_feats=None, precompute=False):
        """Refine ``gv_feat`` by attending over ``att_feats`` with every layer.

        Args:
            gv_feat: global feature. A last dim of size 1 marks it as empty; the
                (masked) mean of ``att_feats`` over dim 1 is then used instead.
            att_feats: features used as key and value2 when not precomputing.
            att_mask: optional mask. It is multiplied into ``att_feats`` for the
                masked mean and forwarded to every layer.
            p_att_feats: used only when ``precompute`` is true. The first half of
                the last dim holds the per-layer keys and the second half the
                per-layer value2s (presumably ``torch.cat(self.precompute(...), -1)``;
                TODO confirm against caller).
            precompute: take keys/value2s from ``p_att_feats``.

        Returns:
            (refined gv_feat with last dim ``embed_dim``, ``att_feats`` unchanged).
        """
        if precompute:
            half = p_att_feats.size(-1) // 2
            keys = p_att_feats.narrow(-1, 0, half)
            value2s = p_att_feats.narrow(-1, half, half)
            # Width of one layer's slice; max() avoids dividing by zero when there are no layers.
            layer_dim = keys.size(-1) // max(len(self.layers), 1)

        if gv_feat.shape[-1] == 1:  # empty gv_feat
            if att_mask is not None:
                mask = att_mask.unsqueeze(-1)
                gv_feat = torch.sum(att_feats * mask, 1) / torch.sum(mask, 1)
            else:
                gv_feat = torch.mean(att_feats, 1)

        feat_arr = [gv_feat]
        for i, layer in enumerate(self.layers):
            if precompute:
                key = keys.narrow(-1, i * layer_dim, layer_dim)
                value2 = value2s.narrow(-1, i * layer_dim, layer_dim)
            else:
                key = att_feats
                value2 = att_feats
            gv_feat = layer(gv_feat, key, att_mask, gv_feat, value2, precompute)
            feat_arr.append(gv_feat)

        gv_feat = self.proj(torch.cat(feat_arr, dim=-1))
        gv_feat = self.layer_norm(gv_feat)
        return gv_feat, att_feats