Source code for xmodaler.modeling.predictor.multimodal_predictor

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li, Jianjie Luo
@contact: yehaoli.sysu@gmail.com, jianjieluo.sysu@gmail.com
"""
import torch
from torch import nn

from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from .build import PREDICTOR_REGISTRY
from ..layers.attention_pooler import AttentionPooler
from ..layers.bert import BertPooler

__all__ = ["MultiModalPredictor", "SingleStreamMultiModalPredictor"]

@PREDICTOR_REGISTRY.register()
class MultiModalPredictor(nn.Module):
    """Understanding / generation prediction heads for a multimodal encoder.

    The understanding head attention-pools the understanding-stream text
    hidden states and each of ``num_understanding_layers`` visual feature
    entries. It sums the pooled visual vectors, multiplies the sum
    element-wise with the pooled text vector, and classifies the product.
    The generation head attention-pools the generation-stream hidden states
    and classifies them directly. Both heads produce ``labels_num`` logits.
    """

    @configurable
    def __init__(
        self,
        *,
        labels_num: int,
        pooler_input_size: int,
        pooler_output_size: int,
        pooler_bn: bool,
        pooler_dropout: float,
        num_understanding_layers: int,
        num_generation_layers: int,
    ):
        """Build the pooling and classification heads.

        Args:
            labels_num: number of output classes of each logits head.
            pooler_input_size: ``hidden_size`` passed to every ``AttentionPooler``.
            pooler_output_size: ``output_size`` of every ``AttentionPooler``;
                also the input width of the ``LayerNorm`` + ``Linear`` heads.
            pooler_bn: forwarded to ``AttentionPooler(use_bn=...)``.
            pooler_dropout: forwarded to ``AttentionPooler(dropout=...)``.
            num_understanding_layers: number of visual poolers. The
                understanding head (``u_pooler``, ``v_pooler``, ``u_logits``)
                is only built when this is > 0.
            num_generation_layers: the generation head (``g_pooler``,
                ``g_logits``) is only built when this is > 0.
        """
        super(MultiModalPredictor, self).__init__()
        self.num_understanding_layers = num_understanding_layers
        self.num_generation_layers = num_generation_layers

        def make_pooler():
            # Every pooler in this module shares the same configuration.
            return AttentionPooler(
                hidden_size=pooler_input_size,
                output_size=pooler_output_size,
                dropout=pooler_dropout,
                use_bn=pooler_bn,
            )

        def make_logits_head():
            return nn.Sequential(
                nn.LayerNorm(pooler_output_size),
                nn.Linear(pooler_output_size, labels_num),
            )

        # Submodules are created in the same order as before so that
        # parameter registration order and init-time RNG use are unchanged.
        if self.num_understanding_layers > 0:
            self.u_pooler = make_pooler()
            self.v_pooler = nn.ModuleList(
                [make_pooler() for _ in range(self.num_understanding_layers)]
            )
            self.u_logits = make_logits_head()

        if self.num_generation_layers > 0:
            self.g_pooler = make_pooler()
            self.g_logits = make_logits_head()

    @classmethod
    def from_config(cls, cfg):
        """Map config values to ``__init__`` keyword arguments."""
        return {
            "labels_num": cfg.MODEL.MM_PREDICTOR.LABELS_NUM,
            "pooler_input_size": cfg.MODEL.MM_PREDICTOR.POOLER_INPUT_SIZE,
            "pooler_output_size": cfg.MODEL.MM_PREDICTOR.POOLER_OUTPUT_SIZE,
            "pooler_bn": cfg.MODEL.MM_PREDICTOR.POOLER_BN,
            "pooler_dropout": cfg.MODEL.MM_PREDICTOR.POOLER_DROPOUT,
            "num_understanding_layers": cfg.MODEL.BERT.NUM_UNDERSTANDING_LAYERS,
            "num_generation_layers": cfg.MODEL.BERT.NUM_GENERATION_LAYERS,
        }

    @classmethod
    def add_config(cls, cfg):
        """Install default ``MODEL.MM_PREDICTOR`` options on ``cfg``.

        Note: this replaces any existing ``cfg.MODEL.MM_PREDICTOR`` node. The
        ``MODEL.BERT.NUM_*_LAYERS`` values read by ``from_config`` are not
        defined here.
        """
        cfg.MODEL.MM_PREDICTOR = CN()
        cfg.MODEL.MM_PREDICTOR.LABELS_NUM = 3129
        cfg.MODEL.MM_PREDICTOR.POOLER_INPUT_SIZE = 768
        cfg.MODEL.MM_PREDICTOR.POOLER_OUTPUT_SIZE = 768
        cfg.MODEL.MM_PREDICTOR.POOLER_BN = False
        cfg.MODEL.MM_PREDICTOR.POOLER_DROPOUT = 0.1

    @staticmethod
    def _last_layer(feats):
        """Return ``feats[-1]`` when ``feats`` is a per-layer list, else ``feats``."""
        return feats[-1] if isinstance(feats, list) else feats

    def _understanding_logits(self, batched_inputs):
        """Pool text and per-layer visual features, fuse them by product, classify.

        Requires the understanding head (``num_understanding_layers > 0``).
        """
        u_tfeats = self._last_layer(batched_inputs[kfg.U_HIDDEN_STATES])
        ext_u_tmasks = batched_inputs[kfg.EXT_U_TOKENS_MASKS]
        vfeats_arr = batched_inputs[kfg.ATT_FEATS]
        ext_vmasks = batched_inputs[kfg.EXT_ATT_MASKS]

        # Visual poolers run before the text pooler, as in the original,
        # so dropout RNG consumption order is unchanged.
        vfeats = 0
        for i, v_pooler in enumerate(self.v_pooler):
            vfeats = vfeats + v_pooler(vfeats_arr[i], ext_vmasks)
        u_tfeats = self.u_pooler(u_tfeats, ext_u_tmasks)
        return self.u_logits(vfeats * u_tfeats)

    def _generation_logits(self, batched_inputs):
        """Pool the generation-stream hidden states and classify them.

        Requires the generation head (``num_generation_layers > 0``).
        """
        g_tfeats = self._last_layer(batched_inputs[kfg.G_HIDDEN_STATES])
        ext_g_tmasks = batched_inputs[kfg.EXT_G_TOKENS_MASKS]
        return self.g_logits(self.g_pooler(g_tfeats, ext_g_tmasks))

    def forward(self, batched_inputs):
        """Compute logits for whichever streams are present in ``batched_inputs``.

        Args:
            batched_inputs: mapping keyed by ``kfg`` constants.
                The understanding head runs when ``U_HIDDEN_STATES`` is present.
                It also reads ``EXT_U_TOKENS_MASKS``, ``ATT_FEATS`` (indexed
                once per understanding layer) and ``EXT_ATT_MASKS``.
                The generation head runs when ``G_HIDDEN_STATES`` is present.
                It also reads ``EXT_G_TOKENS_MASKS``.
                Hidden states given as a list are reduced to their last element.

        Returns:
            dict containing ``U_LOGITS`` and/or ``G_LOGITS`` for the heads
            that ran. In eval mode, when at least one head ran, ``OUTPUT``
            additionally holds a numpy array with the argmax along dim 1
            of the summed softmax probabilities of those heads.
        """
        ret = {}
        probs = []

        if kfg.U_HIDDEN_STATES in batched_inputs:
            u_logits = self._understanding_logits(batched_inputs)
            ret[kfg.U_LOGITS] = u_logits
            if not self.training:
                probs.append(torch.softmax(u_logits, dim=-1))

        if kfg.G_HIDDEN_STATES in batched_inputs:
            g_logits = self._generation_logits(batched_inputs)
            ret[kfg.G_LOGITS] = g_logits
            if not self.training:
                probs.append(torch.softmax(g_logits, dim=-1))

        # Only predict when there is something to argmax over. Previously the
        # accumulator stayed the int 0 in training mode (or when no head ran),
        # and torch.max(0, 1) raised a TypeError. sum() starts from 0, so the
        # result matches the former `0 + u_probs + g_probs` exactly.
        if probs:
            outputs = sum(probs)
            ret[kfg.OUTPUT] = torch.max(outputs, 1)[1].detach().cpu().numpy()
        return ret
@PREDICTOR_REGISTRY.register()
class SingleStreamMultiModalPredictor(nn.Module):
    """Classification head on top of a single-stream multimodal encoder.

    The last entry of ``U_HIDDEN_STATES`` is reduced by the configured pooler
    and projected to ``labels_num`` logits by a single linear layer.
    """

    @configurable
    def __init__(
        self,
        *,
        hidden_size: int,
        labels_num: int,
        pooler
    ):
        """Store the pooler and build the linear classifier.

        Args:
            hidden_size: input width of the classifier.
            labels_num: number of output classes.
            pooler: module applied to the hidden states before ``cls``.
        """
        super().__init__()
        self.labels_num = labels_num
        # Register the pooler before the classifier (parameter order matters
        # for anything that iterates parameters(), e.g. optimizer state).
        self.pooler = pooler
        self.cls = nn.Linear(hidden_size, labels_num)

    @classmethod
    def from_config(cls, cfg):
        """Translate ``cfg`` into keyword arguments for ``__init__``."""
        hidden_size = cfg.MODEL.BERT.HIDDEN_SIZE
        labels_num = cfg.MODEL.MM_PREDICTOR.LABELS_NUM
        pooler = BertPooler(cfg)
        return dict(hidden_size=hidden_size, labels_num=labels_num, pooler=pooler)

    @classmethod
    def add_config(cls, cfg):
        """Install the default ``MODEL.MM_PREDICTOR`` node (replacing any existing one)."""
        predictor_cfg = CN()
        cfg.MODEL.MM_PREDICTOR = predictor_cfg
        predictor_cfg.LABELS_NUM = 3129

    def test_forward(self, u_logits):
        """Convert logits into predicted class indices (used for VQA evaluation).

        Returns:
            dict mapping ``kfg.OUTPUT`` to a numpy array with the argmax along
            dim 1 of ``softmax(u_logits)``.
        """
        probs = torch.softmax(u_logits, dim=-1)
        _, best = torch.max(probs, 1)
        return {kfg.OUTPUT: best.detach().cpu().numpy()}

    def forward(self, batched_inputs):
        """Pool the understanding hidden states and classify them.

        Args:
            batched_inputs: mapping containing ``kfg.U_HIDDEN_STATES``; a list
                of hidden states is reduced to its last element.

        Returns:
            dict with ``kfg.U_LOGITS``; in eval mode also ``kfg.OUTPUT``
            (see ``test_forward``).
        """
        states = batched_inputs[kfg.U_HIDDEN_STATES]
        states = states[-1] if isinstance(states, list) else states
        u_logits = self.cls(self.pooler(states))

        ret = {kfg.U_LOGITS: u_logits}
        if not self.training:
            ret.update(self.test_forward(u_logits))
        return ret