Source code for xmodaler.losses.cross_entropy

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
import torch.nn as nn
from xmodaler.config import configurable
from xmodaler.config import kfg
from .build import LOSSES_REGISTRY

@LOSSES_REGISTRY.register()
class CrossEntropy(nn.Module):
    """Token-level cross-entropy loss for the ``G`` and ``U`` output heads.

    Reads logits and target ids from the model's output dict under the
    ``kfg.G_*`` / ``kfg.U_*`` keys and returns one scalar loss per head that
    is present. Target positions equal to ``-1`` are ignored.
    """

    @configurable
    def __init__(self):
        super(CrossEntropy, self).__init__()
        # Targets equal to -1 are excluded from the loss and from the mean
        # (PyTorch's default reduction). NOTE(review): presumably -1 marks
        # padding; confirm with the data pipeline. If *every* target in a
        # batch is -1, the mean reduction yields NaN.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)

    @classmethod
    def from_config(cls, cfg):
        """Return constructor kwargs from ``cfg``; this loss takes none."""
        return {}

    @classmethod
    def add_config(cls, cfg):
        """Register config options; this loss has none."""
        pass

    def _compute_loss(self, logits, targets):
        """Flatten logits to (N, C) and targets to (N,), then apply the criterion.

        Args:
            logits: tensor whose last dimension is the class dimension.
            targets: tensor of class ids with as many elements as ``logits``
                has positions (all dims but the last).

        Returns:
            Scalar loss tensor.
        """
        # reshape rather than view: view() raises on non-contiguous tensors
        # (e.g. logits sliced or transposed upstream); reshape returns a view
        # when possible and copies only when it must.
        logits = logits.reshape(-1, logits.shape[-1])
        targets = targets.reshape(-1).long()
        return self.criterion(logits, targets)

    def forward(self, outputs_dict):
        """Compute cross-entropy for each head whose logits are present.

        Args:
            outputs_dict: mapping that may contain ``kfg.G_LOGITS`` (with
                ``kfg.G_TARGET_IDS``) and/or ``kfg.U_LOGITS`` (with
                ``kfg.U_TARGET_IDS``).

        Returns:
            dict with ``'CrossEntropy Loss(G)'`` and/or
            ``'CrossEntropy Loss(U)'`` entries; empty if neither head's
            logits are present.
        """
        ret = {}
        if kfg.G_LOGITS in outputs_dict:
            ret['CrossEntropy Loss(G)'] = self._compute_loss(
                outputs_dict[kfg.G_LOGITS], outputs_dict[kfg.G_TARGET_IDS]
            )
        if kfg.U_LOGITS in outputs_dict:
            ret['CrossEntropy Loss(U)'] = self._compute_loss(
                outputs_dict[kfg.U_LOGITS], outputs_dict[kfg.U_TARGET_IDS]
            )
        return ret