Source code for xmodaler.modeling.predictor.bert_predictor

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li, Jianjie Luo
@contact: yehaoli.sysu@gmail.com, jianjieluo.sysu@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F

from xmodaler.config import configurable
from xmodaler.config import kfg
from ..layers.bert import BertPredictionHeadTransform, BertPooler
from .build import PREDICTOR_REGISTRY

__all__ = ["BertPredictionHead", "BertVisualPredictionHead", "BertVisualFeatureRegressionHead", "BertIsMatchedPredictor"]

[docs]@PREDICTOR_REGISTRY.register() class BertPredictionHead(nn.Module):
[docs] @configurable def __init__( self, *, hidden_size, vocab_size, transform ): super(BertPredictionHead, self).__init__() self.transform = transform # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.decoder = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias
[docs] @classmethod def from_config(cls, cfg): return { "hidden_size": cfg.MODEL.BERT.HIDDEN_SIZE, "vocab_size": cfg.MODEL.VOCAB_SIZE, "transform": BertPredictionHeadTransform(cfg) }
[docs] @classmethod def add_config(cls, cfg): pass
[docs] def forward(self, batched_inputs): ret = {} if kfg.U_HIDDEN_STATES in batched_inputs: hidden_states = batched_inputs[kfg.U_HIDDEN_STATES] if isinstance(hidden_states, list): hidden_states = hidden_states[-1] hidden_states = self.transform(hidden_states) u_logits = self.decoder(hidden_states) ret.update({ kfg.U_LOGITS: u_logits }) if kfg.G_HIDDEN_STATES in batched_inputs: hidden_states = batched_inputs[kfg.G_HIDDEN_STATES] if isinstance(hidden_states, list): hidden_states = hidden_states[-1] hidden_states = self.transform(hidden_states) g_logits = self.decoder(hidden_states) ret.update({ kfg.G_LOGITS: g_logits }) return ret
[docs]@PREDICTOR_REGISTRY.register() class BertVisualPredictionHead(nn.Module):
[docs] @configurable def __init__( self, *, hidden_size, v_target_size, transform ): super(BertVisualPredictionHead, self).__init__() self.transform = transform self.decoder = nn.Linear(hidden_size, v_target_size)
[docs] @classmethod def from_config(cls, cfg): return { "hidden_size": cfg.MODEL.BERT.HIDDEN_SIZE, "v_target_size": cfg.MODEL.BERT.V_TARGET_SIZE, "transform": BertPredictionHeadTransform(cfg) }
[docs] @classmethod def add_config(cls, cfg): pass
[docs] def forward(self, batched_inputs): hidden_states = batched_inputs[kfg.ATT_FEATS] if isinstance(hidden_states, list): hidden_states = hidden_states[-1] hidden_states = self.transform(hidden_states) logits = self.decoder(hidden_states) return { kfg.V_LOGITS: logits }
[docs]@PREDICTOR_REGISTRY.register() class BertVisualFeatureRegressionHead(nn.Module):
[docs] @configurable def __init__( self, *, hidden_size, v_feat_dim, transform ): super(BertVisualFeatureRegressionHead, self).__init__() self.transform = transform self.weight = nn.Parameter(nn.Linear(hidden_size, v_feat_dim, bias=False).weight.t()) self.bias = nn.Parameter(torch.zeros(v_feat_dim))
[docs] @classmethod def from_config(cls, cfg): return { "hidden_size": cfg.MODEL.BERT.HIDDEN_SIZE, "v_feat_dim": cfg.MODEL.VISUAL_EMBED.IN_DIM, "transform": BertPredictionHeadTransform(cfg) }
[docs] @classmethod def add_config(cls, cfg): pass
[docs] def forward(self, batched_inputs): hidden_states = batched_inputs[kfg.ATT_FEATS] if isinstance(hidden_states, list): hidden_states = hidden_states[-1] hidden_states = self.transform(hidden_states) v_reg = F.linear(hidden_states, self.weight.t(), self.bias) return { kfg.V_REGRESS: v_reg }
[docs]@PREDICTOR_REGISTRY.register() class BertIsMatchedPredictor(nn.Module):
[docs] @configurable def __init__( self, *, hidden_size, pooler ): super(BertIsMatchedPredictor, self).__init__() self.pooler = pooler self.is_match_cls = nn.Linear(hidden_size, 2)
[docs] @classmethod def from_config(cls, cfg): return { "hidden_size": cfg.MODEL.BERT.HIDDEN_SIZE, "pooler": BertPooler(cfg) }
[docs] @classmethod def add_config(cls, cfg): pass
[docs] def forward(self, batched_inputs): hidden_states = batched_inputs[kfg.U_HIDDEN_STATES] if isinstance(hidden_states, list): hidden_states = hidden_states[-1] pooled_output = self.pooler(hidden_states) is_match_score = self.is_match_cls(pooled_output) return { kfg.ITM_LOGITS: is_match_score }