# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from xmodaler.config import configurable
from xmodaler.config import kfg
from .build import LOSSES_REGISTRY

@LOSSES_REGISTRY.register()
class LabelSmoothing(nn.Module):
    @configurable
    def __init__(
        self,
        *,
        label_smoothing
    ):
        super(LabelSmoothing, self).__init__()
        self.label_smoothing = label_smoothing
        # Probability mass kept on the ground-truth class.
        self.confidence = 1.0 - self.label_smoothing
        # Per-element KL divergence; reduction is done manually in Forward.
        self.criterion = nn.KLDivLoss(reduction='none')
    @classmethod
    def from_config(cls, cfg):
        return {
            "label_smoothing": cfg.LOSSES.LABELSMOOTHING
        }

    @classmethod
    def add_config(cls, cfg):
        pass
    def Forward(self, logits, targets):
        # Flatten to (batch * seq_len, vocab_size) log-probabilities.
        logP = F.log_softmax(logits.view(-1, logits.shape[-1]), dim=-1)
        targets = targets.view(-1)
        # Negative ids mark padded positions; exclude them from the mean.
        mask = targets >= 0

        # Clone before zeroing the padding ids so the caller's tensor is not
        # mutated in place.
        assign_seq = targets.clone()
        assign_seq[assign_seq < 0] = 0

        # Smoothed target distribution: label_smoothing spread uniformly over
        # the non-target classes, 1 - label_smoothing on the target class.
        size = logP.size(1)
        true_dist = logP.clone()
        true_dist.fill_(self.label_smoothing / (size - 1))
        true_dist.scatter_(1, assign_seq.unsqueeze(1), self.confidence)

        # Per-token KL divergence, averaged over the non-padded positions.
        loss = self.criterion(logP, true_dist).sum(1)
        loss = torch.masked_select(loss, mask).mean()
        return loss
    def forward(self, outputs_dict):
        ret = {}
        # Score whichever logit streams the model produced.
        if kfg.G_LOGITS in outputs_dict:
            logits = outputs_dict[kfg.G_LOGITS]
            targets = outputs_dict[kfg.G_TARGET_IDS]
            loss = self.Forward(logits, targets)
            ret.update({ 'LabelSmoothing(G) loss': loss })

        if kfg.U_LOGITS in outputs_dict:
            logits = outputs_dict[kfg.U_LOGITS]
            targets = outputs_dict[kfg.U_TARGET_IDS]
            loss = self.Forward(logits, targets)
            ret.update({ 'LabelSmoothing(U) loss': loss })
        return ret