Source code for xmodaler.losses.bce_logits

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
import torch.nn as nn
from xmodaler.config import configurable
from xmodaler.config import kfg
from .build import LOSSES_REGISTRY

[docs]@LOSSES_REGISTRY.register() class BCEWithLogits(nn.Module):
[docs] @configurable def __init__(self): super(BCEWithLogits, self).__init__() self.criterion = nn.BCEWithLogitsLoss(reduction="mean")
[docs] @classmethod def from_config(cls, cfg): return {}
[docs] @classmethod def add_config(cls, cfg): pass
[docs] def forward(self, outputs_dict): ret = {} targets = outputs_dict[kfg.U_TARGET_IDS] if kfg.G_LOGITS in outputs_dict: logits = outputs_dict[kfg.G_LOGITS] loss = self.criterion(logits, targets) * targets.size(1) ret.update({ 'BCEWithLogits Loss(G)': loss }) if kfg.U_LOGITS in outputs_dict: logits = outputs_dict[kfg.U_LOGITS] loss = self.criterion(logits, targets) * targets.size(1) ret.update({ 'BCEWithLogits Loss(U)': loss }) return ret