Source code for xmodaler.modeling.encoder.updown_encoder

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from torch import nn

from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from .build import ENCODER_REGISTRY

__all__ = ["UpDownEncoder"]

@ENCODER_REGISTRY.register()
class UpDownEncoder(nn.Module):
    """Parameter-free encoder that pools attention features into a global feature.

    The module defines no layers of its own. ``forward`` only derives
    ``kfg.GLOBAL_FEATS`` from ``kfg.ATT_FEATS`` (optionally weighted by
    ``kfg.ATT_MASKS``) when the caller has not already supplied it.
    """

    @configurable
    def __init__(self):
        super().__init__()

    @classmethod
    def from_config(cls, cfg):
        """Build constructor kwargs from ``cfg``; this encoder takes none."""
        return {}

    @classmethod
    def add_config(cls, cfg):
        """Register config options; this encoder adds none."""
        pass

    def forward(self, batched_inputs, mode=None):
        """Compute the pooled global feature for the visual branch.

        Args:
            batched_inputs: Dict-like batch. The following keys are read:
                ``kfg.ATT_FEATS`` (required); ``kfg.ATT_MASKS`` (optional,
                where a missing key or ``None`` means "no mask"); and
                ``kfg.GLOBAL_FEATS`` (optional, reused as-is when present and
                not ``None``).
            mode: ``None`` or ``'v'`` runs the pooling. Any other value makes
                the method return an empty dict.

        Returns:
            dict: ``{kfg.GLOBAL_FEATS: global_feats}`` when pooling runs,
            otherwise ``{}``.

        Raises:
            KeyError: if pooling runs and ``kfg.ATT_FEATS`` is missing.
        """
        ret = {}
        if mode is None or mode == 'v':
            att_feats = batched_inputs[kfg.ATT_FEATS]
            # A missing mask key is treated exactly like an explicit None mask.
            att_masks = batched_inputs.get(kfg.ATT_MASKS, None)
            global_feats = batched_inputs.get(kfg.GLOBAL_FEATS, None)

            if global_feats is None:
                # NOTE(review): assumes att_feats is (batch, num_regions, dim)
                # and att_masks is (batch, num_regions); pooling reduces dim 1.
                # TODO confirm against the data loader.
                if att_masks is None:
                    global_feats = torch.mean(att_feats, 1)
                else:
                    att_feats_masks = att_feats * att_masks.unsqueeze(-1)
                    att_masks_sum = att_masks.sum(-1)
                    # A fully masked row would otherwise compute 0 / 0 = NaN.
                    # Its numerator is zero, so dividing by 1 pools it to a
                    # zero vector; every non-zero denominator is left as-is.
                    att_masks_sum = att_masks_sum.masked_fill(att_masks_sum == 0, 1)
                    global_feats = att_feats_masks.sum(1) / att_masks_sum.unsqueeze(-1)
            ret.update({kfg.GLOBAL_FEATS: global_feats})
        return ret