# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li, Jianjie Luo
@contact: yehaoli.sysu@gmail.com, jianjieluo.sysu@gmail.com
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from xmodaler.config import configurable
from xmodaler.config import kfg
from .build import LOSSES_REGISTRY
from .triplet import BatchTriplet
@LOSSES_REGISTRY.register()
class PretrainLosses(nn.Module):
    """Collection of pre-training losses computed from a batched input dict.

    Each loss is evaluated only when the key carrying its predictions is
    present in ``batched_inputs``; the result is a dict mapping a
    human-readable loss name to its value. The possible entries are:

    * whatever ``BatchTriplet`` returns (its ``'BatchTriplet Loss'`` entry is
      divided by ``len(batched_inputs[kfg.IDS])``),
    * ``"Image Text Matching"`` (cross entropy over 2-way logits),
    * ``"Masked Object Feature Regression"`` (mean squared error),
    * ``"Masked Object Classification"`` (KL divergence, summed and divided
      by the number of selected rows),
    * ``"Masked Language Modeling"`` (cross entropy),
    * ``"Masked Sentence Generation"`` (cross entropy),
    * ``"No Loss in this Iteration"`` (zero placeholder when nothing else
      was produced).
    """

    @configurable
    def __init__(self, margin, max_violation):
        """Create the underlying criterion modules.

        Args:
            margin: forwarded as the first argument of ``BatchTriplet``.
            max_violation: forwarded as the second argument of
                ``BatchTriplet``.
        """
        super(PretrainLosses, self).__init__()
        # Targets equal to -1 are ignored by the cross entropy.
        self.xe_loss = nn.CrossEntropyLoss(ignore_index=-1)
        # Element-wise KL; the reduction is done by hand in ``forward``.
        self.kl_loss = nn.KLDivLoss(reduction="none")
        self.triplet_loss = BatchTriplet(margin, max_violation)
        self.mse_loss = nn.MSELoss(reduction="mean")

    @classmethod
    def from_config(cls, cfg):
        """Return the constructor kwargs read from ``cfg.LOSSES``."""
        return {
            "margin": cfg.LOSSES.MARGIN,
            "max_violation": cfg.LOSSES.MAX_VIOLATION
        }

    @classmethod
    def add_config(cls, cfg):
        """No extra configuration entries are registered by this loss."""
        pass

    def select_logits_targets_by_mask(self, tensor, target, mask):
        """Keep only the rows of ``tensor`` / ``target`` where ``mask`` is set.

        Args:
            tensor: 2-D tensor indexed as ``tensor[mask, :]``.
            target: tensor indexed as ``target[mask]``.
            mask: boolean index over the first dimension.

        Returns:
            Tuple ``(tensor, target)`` restricted to the selected rows.
        """
        tensor = tensor[mask, :]
        target = target[mask]
        return tensor, target

    def _select_masked_visual(self, preds, batched_inputs):
        """Flatten visual predictions/targets and keep rows with label > 0.

        Targets come from ``batched_inputs[kfg.V_TARGET]`` and the selection
        labels from ``batched_inputs[kfg.V_TARGET_LABELS]``.
        """
        targets = batched_inputs[kfg.V_TARGET]
        if targets.size(1) + 1 == preds.size(1):
            # remove global avg vfeat (predictions carry one extra leading slot)
            preds = preds[:, 1:, :].reshape(-1, preds.size(-1))
        else:
            preds = preds.view(-1, preds.size(-1))
        targets = targets.view(-1, targets.size(-1))
        labels = batched_inputs[kfg.V_TARGET_LABELS].view(-1)
        return self.select_logits_targets_by_mask(preds, targets, labels > 0)

    def _masked_token_xe(self, logits, target_ids):
        """Cross entropy over positions whose target id is >= 0.

        Returns ``None`` when no position is selected.
        """
        logits = logits.view(-1, logits.size(-1))
        target_ids = target_ids.view(-1)
        logits, target_ids = self.select_logits_targets_by_mask(
            logits, target_ids, target_ids >= 0
        )
        if target_ids.size(0) > 0:
            return self.xe_loss(logits, target_ids)
        return None

    def forward(self, batched_inputs):
        """Compute every loss whose inputs are present in ``batched_inputs``.

        Args:
            batched_inputs: mapping keyed by ``kfg`` constants that holds the
                model predictions and their targets.

        Returns:
            Dict mapping loss names to loss values. If no loss could be
            computed, a single ``"No Loss in this Iteration"`` entry holding
            a zero tensor is returned instead.
        """
        ret = {}

        if kfg.OUTPUT in batched_inputs:
            triplet_loss = self.triplet_loss(batched_inputs)
            triplet_loss['BatchTriplet Loss'] /= len(batched_inputs[kfg.IDS])
            ret.update(triplet_loss)

        if kfg.ITM_LOGITS in batched_inputs:
            is_match_score = batched_inputs[kfg.ITM_LOGITS]
            itm_neg_label = batched_inputs[kfg.ITM_NEG_LABEL]
            is_match_loss = self.xe_loss(
                is_match_score.view(-1, 2), itm_neg_label.view(-1)
            )
            ret.update({ "Image Text Matching": is_match_loss })

        if kfg.V_REGRESS in batched_inputs:
            v_reg, v_targets = self._select_masked_visual(
                batched_inputs[kfg.V_REGRESS], batched_inputs
            )
            if v_targets.size(0) > 0:
                v_loss = self.mse_loss(v_reg, v_targets)
                ret.update({ "Masked Object Feature Regression": v_loss })

        if kfg.V_LOGITS in batched_inputs:
            v_logits, v_targets = self._select_masked_visual(
                batched_inputs[kfg.V_LOGITS], batched_inputs
            )
            if v_targets.size(0) > 0:
                v_loss = self.kl_loss(F.log_softmax(v_logits, dim=-1), v_targets)
                # Sum over all elements, normalise by the number of selected rows.
                v_loss = torch.sum(v_loss) / v_loss.size(0)
                ret.update({ "Masked Object Classification": v_loss })

        if kfg.U_LOGITS in batched_inputs:
            u_loss = self._masked_token_xe(
                batched_inputs[kfg.U_LOGITS], batched_inputs[kfg.U_TARGET_IDS]
            )
            if u_loss is not None:
                ret.update({ "Masked Language Modeling": u_loss })

        if kfg.G_LOGITS in batched_inputs:
            g_loss = self._masked_token_xe(
                batched_inputs[kfg.G_LOGITS], batched_inputs[kfg.G_TARGET_IDS]
            )
            if g_loss is not None:
                ret.update({ "Masked Sentence Generation": g_loss })

        if not ret:
            print("No Loss in this Iteration")
            # Same placement as before when CUDA exists, but do not crash on
            # CPU-only machines where an unconditional .cuda() would raise.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            ret.update({ "No Loss in this Iteration": torch.tensor(0, device=device) })
        return ret