Source code for xmodaler.losses.triplet

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li, Jianjie Luo
@contact: yehaoli.sysu@gmail.com, jianjieluo.sysu@gmail.com
"""
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from xmodaler.config import configurable
from xmodaler.config import kfg
from .build import LOSSES_REGISTRY


@LOSSES_REGISTRY.register()
class Triplet(nn.Module):
    """Hinge-style triplet loss over grouped matching logits.

    The logits stored under ``kfg.U_LOGITS`` are reshaped into rows of
    ``outputs_dict[kfg.SAMPLE_PER_SAMPLE]`` entries. Column 0 of every row is
    treated as the positive score and each remaining column as one negative.
    The loss is the mean of ``margin - positive + negative`` taken over the
    pairs for which that quantity is strictly positive.
    """

    @configurable
    def __init__(self, margin):
        """Store the margin used in the hinge term."""
        super(Triplet, self).__init__()
        self.margin = margin

    @classmethod
    def from_config(cls, cfg):
        """Read the constructor arguments from ``cfg.LOSSES``."""
        return {"margin": cfg.LOSSES.MARGIN}

    @classmethod
    def add_config(cls, cfg):
        """This loss registers no additional config entries."""
        pass

    def forward(self, outputs_dict):
        """Compute the loss and return it as ``{'Triplet Loss': tensor}``."""
        logits = outputs_dict[kfg.U_LOGITS]
        logits = logits.view(-1, outputs_dict[kfg.SAMPLE_PER_SAMPLE])

        # NOTE: column 0 holds the positive score; every later column is
        # paired with it in turn, and the per-column (rows, 2) blocks are
        # stacked along dim 0.
        positive = logits[:, 0]
        pair_blocks = [
            torch.stack((positive, logits[:, col]), dim=1)
            for col in range(1, logits.shape[1])
        ]
        pairs = torch.cat(pair_blocks, dim=0)

        violation = self.margin - pairs[:, 0] + pairs[:, 1]
        active = violation[violation > 0]

        # mean() of an empty tensor is NaN; sum() of an empty tensor is 0.
        if len(active) > 0:
            loss = torch.mean(active)
        else:
            loss = torch.sum(active)
        return {'Triplet Loss': loss}
@LOSSES_REGISTRY.register()
class BatchTriplet(nn.Module):
    """In-batch hinge triplet loss over a square score matrix.

    ``outputs_dict[kfg.OUTPUT]`` is read as a square matrix whose diagonal
    holds the positive-pair scores (it is accessed via ``scores.diag()``).
    Every off-diagonal entry is a candidate negative, except entries whose
    row and column share the same id in ``outputs_dict[kfg.IDS]``; those are
    masked out together with the diagonal.

    With ``max_violation`` set, only the largest violation per row / per
    column is kept and the kept values are summed. Otherwise the strictly
    positive violations are averaged for each direction.
    """

    @configurable
    def __init__(self, margin, max_violation):
        """Store the hinge margin and the hardest-negative switch."""
        super(BatchTriplet, self).__init__()
        self.margin = margin
        self.max_violation = max_violation

    @classmethod
    def from_config(cls, cfg):
        """Read the constructor arguments from ``cfg.LOSSES``."""
        return {
            "margin": cfg.LOSSES.MARGIN,
            "max_violation": cfg.LOSSES.MAX_VIOLATION
        }

    @classmethod
    def add_config(cls, cfg):
        """No-op hook, provided for parity with ``Triplet.add_config``."""
        # NOTE(review): whether the framework invokes this hook on every
        # registered loss is not visible in this file -- confirm.
        pass

    @staticmethod
    def _mean_or_zero(values):
        """Return ``values.mean()``, or a zero-valued sum when it is empty."""
        # mean() of an empty tensor is NaN, which would poison the loss when
        # no triplet violates the margin. Same guard as Triplet.forward.
        return values.mean() if values.numel() > 0 else values.sum()

    def forward(self, outputs_dict):
        """Compute the loss and return it as ``{'BatchTriplet Loss': tensor}``."""
        scores = outputs_dict[kfg.OUTPUT]
        device = scores.device

        # same_id[i, j] is True when items i and j carry the same id. Such
        # pairs are not valid negatives. Built on the host and then moved to
        # wherever `scores` lives (not unconditionally to CUDA).
        ids = np.array(outputs_dict[kfg.IDS])
        same_id = np.expand_dims(ids, axis=1) == np.expand_dims(ids, axis=0)
        same_id = torch.from_numpy(same_id).to(device)

        # compute image-sentence score matrix
        batch_size = scores.size(0)
        diagonal = scores.diag().view(batch_size, 1)
        d1 = diagonal.expand_as(scores)      # d1[i, j] == scores[i, i]
        d2 = diagonal.t().expand_as(scores)  # d2[i, j] == scores[j, j]

        # cost_s[i, j] = max(0, margin + scores[i, j] - scores[i, i])
        # (labelled "caption retrieval" in the original comments)
        cost_s = (self.margin + scores - d1).clamp(min=0)
        # cost_im[i, j] = max(0, margin + scores[i, j] - scores[j, j])
        # (labelled "image retrieval" in the original comments)
        cost_im = (self.margin + scores - d2).clamp(min=0)

        # Clear the diagonal (the positives themselves) and same-id pairs.
        eye = torch.eye(batch_size, device=device) > .5
        ignore = eye | same_id
        cost_s = cost_s.masked_fill(ignore, 0)
        cost_im = cost_im.masked_fill(ignore, 0)

        if self.max_violation:
            # keep the maximum violating negative for each query
            cost_s = cost_s.max(1)[0]
            cost_im = cost_im.max(0)[0]
            loss = cost_s.sum() + cost_im.sum()
        else:
            # average over the strictly positive violations only
            cost_s = torch.masked_select(cost_s, cost_s > 0)
            cost_im = torch.masked_select(cost_im, cost_im > 0)
            loss = self._mean_or_zero(cost_s) + self._mean_or_zero(cost_im)

        return {'BatchTriplet Loss': loss}