Source code for xmodaler.losses.reward_criterion

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
import torch.nn as nn
from xmodaler.config import configurable
from xmodaler.config import kfg
from .build import LOSSES_REGISTRY

@LOSSES_REGISTRY.register()
class RewardCriterion(nn.Module):
    """Reward-weighted negative log-probability loss over sampled sequences.

    Only tokens up to and including the first end-of-sentence token of each
    sampled sequence contribute to the loss; everything after it is masked
    out.
    """

    @configurable
    def __init__(self, eos_id):
        """Store the token id that marks the end of a sentence.

        Args:
            eos_id: id compared against the sampled token ids to locate the
                end of each sequence.
        """
        super().__init__()
        self.eos_id = eos_id

    @classmethod
    def from_config(cls, cfg):
        """Build the constructor kwargs from ``cfg.SCORER.EOS_ID``."""
        return {
            'eos_id': cfg.SCORER.EOS_ID
        }

    @classmethod
    def add_config(cls, cfg):
        """This criterion adds no configuration entries of its own."""
        pass

    def forward(self, outputs_dict):
        """Compute the mean reward-weighted negative log-probability.

        Args:
            outputs_dict: mapping that provides
                ``kfg.G_SENTS_IDS`` (sampled token ids),
                ``kfg.G_LOGP`` (log-probabilities of the sampled tokens) and
                ``kfg.REWARDS`` (one reward per sequence).
                NOTE(review): the indexing below assumes the ids and
                log-probabilities are 2-D ``(batch, seq_len)`` and that
                rewards holds ``batch`` values -- confirm against the caller.

        Returns:
            dict with the single key ``'RewardCriterion'`` holding the
            scalar loss tensor.
        """
        seq = outputs_dict[kfg.G_SENTS_IDS]
        log_probs = outputs_dict[kfg.G_LOGP]
        rewards = outputs_dict[kfg.REWARDS]

        # True at every position strictly before the first EOS token.
        before_eos = torch.cumsum(seq == self.eos_id, dim=-1) == 0
        # Shift right by one step with a leading "keep" column so that the
        # first EOS token itself is also included in the loss.
        first_column = before_eos.new_ones((before_eos.size(0), 1))
        mask = torch.cat([first_column, before_eos[:, :-1]], 1)

        # Broadcast the per-sequence reward to every time step.
        rewards = rewards.view(-1, 1).expand_as(log_probs)

        kept_log_probs = torch.masked_select(log_probs, mask)
        kept_rewards = torch.masked_select(rewards, mask)
        loss = torch.mean(-kept_log_probs * kept_rewards)

        return {'RewardCriterion': loss}