Source code for xmodaler.modeling.meta_arch.uniter

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo
@contact: jianjieluo.sysu@gmail.com
"""
import torch
import torch.distributed as dist
import random

from xmodaler.config import configurable
from xmodaler.config import kfg
from xmodaler.utils.distributed import any_broadcast
from ..predictor import build_v_predictor, build_predictor_with_name
from .transformer_enc_dec import TransformerEncoderDecoder
from .build import META_ARCH_REGISTRY

from ..embedding import build_embeddings
from ..encoder import build_encoder
from ..predictor import build_predictor

# Public meta-architectures exported by this module.
__all__ = [
    "UniterPretrain",
    "UniterForMMUnderstanding",
    "UniterRetrieval",
]

@META_ARCH_REGISTRY.register()
class UniterForMMUnderstanding(TransformerEncoderDecoder):
    """UNITER meta-architecture for multi-modal understanding tasks.

    The model is configured without a decoder, greedy decoder, beam searcher
    or visual predictor (see ``from_config``). An additional ITM
    (image-text matching) predictor is kept so that its pooler can be shared
    with the task predictor in ``bind_or_init_weights``.
    """

    @configurable
    def __init__(
        self,
        *,
        vocab_size,
        max_seq_len,
        token_embed,
        visual_embed,
        encoder,
        decoder,
        predictor,
        greedy_decoder,
        beam_searcher,
        v_predictor,
        itm_predictor,
    ):
        """Forward the common components to the base class and keep the ITM head.

        All arguments are keyword-only. Everything except ``itm_predictor`` is
        passed unchanged to ``TransformerEncoderDecoder.__init__``;
        ``itm_predictor`` is stored on the instance.
        """
        super().__init__(
            vocab_size=vocab_size,
            max_seq_len=max_seq_len,
            token_embed=token_embed,
            visual_embed=visual_embed,
            encoder=encoder,
            decoder=decoder,
            predictor=predictor,
            greedy_decoder=greedy_decoder,
            beam_searcher=beam_searcher,
            v_predictor=v_predictor
        )
        self.itm_predictor = itm_predictor

    @classmethod
    def from_config(cls, cfg):
        """Build the keyword arguments for ``__init__`` from ``cfg``.

        Returns:
            dict: constructor arguments keyed by parameter name.
        """
        ret = {
            # basic config
            "token_embed": build_embeddings(cfg, cfg.MODEL.TOKEN_EMBED.NAME),
            "visual_embed": build_embeddings(cfg, cfg.MODEL.VISUAL_EMBED.NAME),
            "encoder": build_encoder(cfg),
            "decoder": None,
            "predictor": build_predictor(cfg),
            "greedy_decoder": None,
            "beam_searcher": None,
            "vocab_size": cfg.MODEL.VOCAB_SIZE,
            "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
            "v_predictor": None,

            # uniter pretrain config, in order to load the pretrained pooler
            "itm_predictor": build_predictor_with_name(cfg, 'BertIsMatchedPredictor')
        }
        return ret

    def bind_or_init_weights(self):
        """Make the task predictor use the ITM predictor's pooler."""
        self.predictor.pooler = self.itm_predictor.pooler

    def get_extended_attention_mask(self, batched_inputs):
        """Build token/visual masks and their additive attention variants.

        If ``batched_inputs`` has no ``kfg.TOKENS_MASKS`` entry, an all-ones
        mask of shape ``(batch, self.max_seq_len)`` is created and stored back
        into ``batched_inputs``. It is allocated on the same device as
        ``kfg.ATT_MASKS`` rather than unconditionally on CUDA, so the method
        also works on CPU and keeps both masks on one device.

        Args:
            batched_inputs (dict): must contain ``kfg.ATT_MASKS``; may contain
                ``kfg.TOKENS_MASKS``.

        Returns:
            dict: ``kfg.TOKENS_MASKS`` (token mask cast to the model dtype),
            ``kfg.EXT_U_TOKENS_MASKS`` (token mask with two singleton dims
            inserted, mapped to 0 / -10000), ``kfg.ATT_MASKS`` (visual mask
            with two singleton dims inserted) and ``kfg.EXT_ATT_MASKS``
            (that visual mask mapped to 0 / -10000).
        """
        vmasks = batched_inputs[kfg.ATT_MASKS]

        if kfg.TOKENS_MASKS not in batched_inputs:
            # Default: every token position is visible.
            batched_inputs[kfg.TOKENS_MASKS] = torch.ones(
                (vmasks.size(0), self.max_seq_len),
                device=vmasks.device,
            )

        # Masks are cast to the parameter dtype (e.g. for half precision).
        param_dtype = next(self.parameters()).dtype

        tmasks = batched_inputs[kfg.TOKENS_MASKS]
        tmasks = tmasks.to(dtype=param_dtype)
        ext_u_tmasks = tmasks.unsqueeze(1).unsqueeze(2)
        # 1 -> 0.0 (keep), 0 -> -10000.0 (suppressed when added to scores).
        ext_u_tmasks = (1.0 - ext_u_tmasks) * -10000.0

        vmasks = vmasks.to(dtype=param_dtype)
        # NOTE: the returned ATT_MASKS is this unsqueezed tensor, not the
        # original 2-D mask.
        vmasks = vmasks.unsqueeze(1).unsqueeze(2)
        ext_vmasks = (1.0 - vmasks) * -10000.0

        return {
            kfg.TOKENS_MASKS: tmasks,
            kfg.EXT_U_TOKENS_MASKS: ext_u_tmasks,
            kfg.ATT_MASKS: vmasks,
            kfg.EXT_ATT_MASKS: ext_vmasks
        }
@META_ARCH_REGISTRY.register()
class UniterRetrieval(UniterForMMUnderstanding):
    """UNITER variant whose predictor is initialised from the ITM head."""

    @configurable
    def __init__(
        self,
        *,
        vocab_size,
        max_seq_len,
        token_embed,
        visual_embed,
        encoder,
        decoder,
        predictor,
        greedy_decoder,
        beam_searcher,
        v_predictor,
        itm_predictor,
    ):
        """Pass every keyword argument through to ``UniterForMMUnderstanding``."""
        super().__init__(
            vocab_size=vocab_size,
            max_seq_len=max_seq_len,
            token_embed=token_embed,
            visual_embed=visual_embed,
            encoder=encoder,
            decoder=decoder,
            predictor=predictor,
            greedy_decoder=greedy_decoder,
            beam_searcher=beam_searcher,
            v_predictor=v_predictor,
            itm_predictor=itm_predictor,
        )

    def bind_or_init_weights(self):
        """Share the ITM pooler and seed the classifier from the ITM head.

        The predictor's ``cls`` layer receives the first row of the ITM
        ``is_match_cls`` weight and the first element of its bias.
        """
        itm_head = self.itm_predictor
        task_head = self.predictor

        task_head.pooler = itm_head.pooler

        match_cls = itm_head.is_match_cls
        task_head.cls.weight.data = match_cls.weight.data[:1, :]
        task_head.cls.bias.data = match_cls.bias.data[:1]
@META_ARCH_REGISTRY.register()
class UniterPretrain(TransformerEncoderDecoder):
    """UNITER meta-architecture for multi-task pre-training.

    Each forward pass randomly picks one task from a sampling pool built from
    ``tasks`` and ``mix_ratio``, prepares the inputs for that task and runs
    the matching prediction head. The task names handled in this class are
    ``'itm'``, ``'mlm'``, ``'mrfr'`` and ``'mrc-kl'``.
    """

    @configurable
    def __init__(
        self,
        *,
        vocab_size,
        max_seq_len,
        token_embed,
        visual_embed,
        encoder,
        decoder,
        predictor,
        greedy_decoder,
        beam_searcher,
        v_predictor,
        v_regressor,
        itm_predictor,
        tasks,
        mix_ratio
    ):
        """Set up the shared backbone, pre-training heads and task sampler.

        All arguments are keyword-only. ``v_regressor``, ``itm_predictor``,
        ``tasks`` and ``mix_ratio`` are handled here; everything else is
        passed unchanged to ``TransformerEncoderDecoder.__init__``.

        Args:
            v_regressor: head used for the ``'mrfr'`` task.
            itm_predictor: head used for the ``'itm'`` task.
            tasks: iterable of task names.
            mix_ratio: iterable of integers, one per task; a task with ratio
                ``r`` appears ``r`` times in the sampling pool.
        """
        super().__init__(
            vocab_size=vocab_size,
            max_seq_len=max_seq_len,
            token_embed=token_embed,
            visual_embed=visual_embed,
            encoder=encoder,
            decoder=decoder,
            predictor=predictor,
            greedy_decoder=greedy_decoder,
            beam_searcher=beam_searcher,
            v_predictor=v_predictor
        )
        self.v_regressor = v_regressor
        self.itm_predictor = itm_predictor

        # Tie the regressor weight to the visual embedding weight.
        self.v_regressor.weight = self.visual_embed.embeddings.weight

        # prepare for random sample pretraining task
        self.sampling_pool = []
        for name, ratio in zip(tasks, mix_ratio):
            self.sampling_pool.extend([name] * ratio)

        # Ask torch.distributed directly whether a process group exists
        # instead of calling get_world_size() under a bare ``except``, which
        # would also hide unrelated errors (and KeyboardInterrupt).
        self.distributed = dist.is_available() and dist.is_initialized()
        self.world_size = dist.get_world_size() if self.distributed else 1

    @classmethod
    def from_config(cls, cfg):
        """Build the keyword arguments for ``__init__`` from ``cfg``.

        Returns:
            dict: constructor arguments keyed by parameter name.

        Raises:
            AssertionError: if ``cfg.MODEL.BERT.V_TARGET_SIZE`` is not
                positive, or the task list and mix-ratio list differ in
                length.
        """
        assert cfg.MODEL.BERT.V_TARGET_SIZE > 0, \
            "MODEL.BERT.V_TARGET_SIZE must be positive"
        assert len(cfg.MODEL.PRETRAIN_TASKS) == len(cfg.MODEL.PRETRAIN_TASKS_MIX_RATIO), \
            "MODEL.PRETRAIN_TASKS and MODEL.PRETRAIN_TASKS_MIX_RATIO must have the same length"

        ret = {
            # basic config
            "token_embed": build_embeddings(cfg, cfg.MODEL.TOKEN_EMBED.NAME),
            "visual_embed": build_embeddings(cfg, cfg.MODEL.VISUAL_EMBED.NAME),
            "encoder": build_encoder(cfg),
            "decoder": None,
            "predictor": build_predictor(cfg),
            "greedy_decoder": None,
            "beam_searcher": None,
            "vocab_size": cfg.MODEL.VOCAB_SIZE,
            "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,

            # uniter pretrain config
            "v_predictor": build_v_predictor(cfg),
            "v_regressor": build_predictor_with_name(cfg, cfg.MODEL.V_REGRESSOR),
            "itm_predictor": build_predictor_with_name(cfg, cfg.MODEL.ITM_PREDICTOR),
            "tasks": tuple(cfg.MODEL.PRETRAIN_TASKS),
            'mix_ratio': tuple(cfg.MODEL.PRETRAIN_TASKS_MIX_RATIO)
        }
        return ret

    @classmethod
    def add_config(cls, cfg, tmp_cfg):
        """Register the pre-training config keys and their defaults on ``cfg``."""
        super().add_config(cfg, tmp_cfg)
        cfg.MODEL.V_REGRESSOR = ''
        cfg.MODEL.ITM_PREDICTOR = ''
        cfg.MODEL.PRETRAIN_TASKS = ['itm', 'mlm', 'mrfr', 'mrc-kl']
        cfg.MODEL.PRETRAIN_TASKS_MIX_RATIO = [1, 1, 1, 1]

    def preprocess_inputs(self, inputs, task_name):
        """Select masked/unmasked inputs and targets for ``task_name``.

        ``inputs`` is modified in place and also returned. For each task the
        modality that is not being predicted is replaced by its ``*_WO_MASK``
        counterpart. For every task except ``'itm'``, targets of samples
        whose ``kfg.ITM_NEG_LABEL`` is non-zero are set to -1.

        Args:
            inputs (dict): batch dictionary keyed by ``kfg`` constants.
            task_name (str): one of ``'itm'``, ``'mlm'``, ``'mrfr'``,
                ``'mrc-kl'``.

        Returns:
            dict: the same ``inputs`` object.

        Raises:
            NotImplementedError: if ``task_name`` is not one of the above.
        """
        if task_name == 'itm':
            inputs[kfg.ATT_FEATS] = inputs[kfg.ATT_FEATS_WO_MASK]
            inputs[kfg.U_TOKENS_IDS] = inputs[kfg.U_TOKENS_IDS_WO_MASK]
        elif task_name == 'mlm':
            inputs[kfg.ATT_FEATS] = inputs[kfg.ATT_FEATS_WO_MASK]
        elif task_name == 'mrfr':
            inputs[kfg.U_TOKENS_IDS] = inputs[kfg.U_TOKENS_IDS_WO_MASK]
            # The regression target is the unmasked visual feature.
            inputs[kfg.V_TARGET] = inputs[kfg.ATT_FEATS_WO_MASK]
        elif task_name == 'mrc-kl':
            inputs[kfg.U_TOKENS_IDS] = inputs[kfg.U_TOKENS_IDS_WO_MASK]
        else:
            raise NotImplementedError(
                "Unsupported pretraining task: {!r}".format(task_name)
            )

        if task_name != 'itm':
            # mask out neg vl pair
            itm_neg_label = inputs[kfg.ITM_NEG_LABEL]
            keep = (itm_neg_label == 0).long().unsqueeze(1)

            # Entries whose product is 0 (negative sample, or label already
            # 0) are written as -1 in the target tensors, in place.
            image_label = inputs[kfg.V_TARGET_LABELS] * keep
            inputs[kfg.V_TARGET_LABELS][image_label == 0] = -1

            masked_lm_labels = inputs[kfg.U_TARGET_IDS] * keep
            inputs[kfg.U_TARGET_IDS][masked_lm_labels == 0] = -1

        return inputs

    def _forward(self, batched_inputs):
        """Run one randomly sampled pre-training task on the batch.

        ``batched_inputs`` is updated in place with the attention masks, the
        embedding and encoder outputs, and the output of the selected head,
        and is then returned.

        Args:
            batched_inputs (dict): batch dictionary keyed by ``kfg`` constants.

        Returns:
            dict: the same dictionary, extended with all intermediate and
            head outputs.
        """
        inputs = batched_inputs
        masks = self.get_extended_attention_mask(batched_inputs)
        inputs.update(masks)

        # random select a task
        task_name = random.choice(self.sampling_pool)
        if self.distributed:
            # make sure all process is training same task
            dist.barrier()
            task_name = any_broadcast(task_name, 0, n_gpu=self.world_size)
            dist.barrier()

        # Preprocess
        inputs = self.preprocess_inputs(inputs, task_name)

        # Forward embeddings
        ve_out = self.visual_embed(batched_inputs)
        inputs.update(ve_out)

        te_out = self.token_embed(batched_inputs)
        inputs.update(te_out)

        # Forward encoder
        encoder_out = self.encoder(inputs)
        inputs.update(encoder_out)

        # Forward Head
        if task_name == 'itm':
            scores = self.itm_predictor(inputs)
            inputs.update(scores)
        elif task_name == 'mlm':
            tlogits = self.predictor(inputs)
            inputs.update(tlogits)
        elif task_name == 'mrfr':
            vregs = self.v_regressor(inputs)
            inputs.update(vregs)
        elif task_name == 'mrc-kl':
            vlogits = self.v_predictor(inputs)
            inputs.update(vlogits)

        return inputs