# Source code for xmodaler.modeling.decoder.attribute_decoder

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from torch import nn

from xmodaler.config import configurable
from xmodaler.config import CfgNode as CN
from xmodaler.config import kfg
from ..layers.base_attention import BaseAttention
from .decoder import Decoder
from .build import DECODER_REGISTRY

__all__ = ["AttributeDecoder"]

@DECODER_REGISTRY.register()
class AttributeDecoder(Decoder):
    """Single-layer LSTM decoder primed with attributes and global visual features.

    ``preprocess`` runs two warm-up steps through the LSTM cell: first the
    projected attribute vector, then the projected global visual feature.
    ``forward`` then advances the same cell by one step per call, using the
    current token embedding as input.
    """

    @configurable
    def __init__(
        self,
        *,
        hidden_size: int,
        token_embed_dim: int,
        visual_feat_dim: int,
        attribute_dim: int,
        dropout: float
    ):
        """Build the projection layers and the LSTM cell.

        Args:
            hidden_size: Size of the LSTM hidden/cell state; also the output
                size of both projection layers and the LSTM input size.
            token_embed_dim: Accepted for configuration compatibility but not
                used by this constructor.
            visual_feat_dim: Input size of the global visual feature projection.
            attribute_dim: Input size of the attribute projection.
            dropout: Dropout probability; no dropout module is created when
                this is not greater than 0.
        """
        super().__init__()
        self.num_layers = 1
        self.hidden_size = hidden_size

        # NOTE(review): ``forward`` feeds the token embedding straight into
        # ``self.lstm``, whose input size is ``hidden_size``; presumably
        # ``token_embed_dim`` must equal ``hidden_size`` -- confirm with configs.
        self.attribute_fc = nn.Linear(attribute_dim, hidden_size)
        self.vfeat_fc = nn.Linear(visual_feat_dim, hidden_size)
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    @classmethod
    def from_config(cls, cfg):
        """Map config entries to the keyword arguments of ``__init__``."""
        return {
            "hidden_size": cfg.MODEL.DECODER_DIM,
            "token_embed_dim": cfg.MODEL.TOKEN_EMBED.DIM,
            "visual_feat_dim": cfg.MODEL.VISUAL_EMBED.IN_DIM,
            "attribute_dim": cfg.MODEL.LSTMA.ATTRIBUTE_DIM,
            "dropout": cfg.MODEL.LSTMA.DROPOUT
        }

    @classmethod
    def add_config(cls, cfg):
        """Register the ``MODEL.LSTMA`` config node with its default values."""
        cfg.MODEL.LSTMA = CN()
        cfg.MODEL.LSTMA.ATTRIBUTE_DIM = 1000
        cfg.MODEL.LSTMA.DROPOUT = 0.5

    def _apply_dropout(self, x):
        """Return ``x`` passed through dropout, or unchanged if dropout is disabled."""
        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def preprocess(self, batched_inputs):
        """Prime the LSTM state with attribute and global visual features.

        Reads ``kfg.ATTRIBUTE`` and ``kfg.GLOBAL_FEATS`` from
        ``batched_inputs``, runs two LSTM steps, and writes the resulting
        states back under ``kfg.G_HIDDEN_STATES`` / ``kfg.G_CELL_STATES``.

        Args:
            batched_inputs: Dict-like batch; updated in place.

        Returns:
            The same ``batched_inputs`` object, with the state entries set to
            one-element lists.
        """
        attributes = batched_inputs[kfg.ATTRIBUTE]
        gv_feats = batched_inputs[kfg.GLOBAL_FEATS]

        # ``init_states`` is not defined in this class; presumably inherited
        # from ``Decoder``. It is called with the size of the first dimension
        # of ``attributes``.
        init_states = self.init_states(attributes.shape[0])
        hidden_states = init_states[kfg.G_HIDDEN_STATES]
        cell_states = init_states[kfg.G_CELL_STATES]

        # t = -2: feed the projected attributes.
        p_attributes = self._apply_dropout(self.attribute_fc(attributes))
        h1_a, c1_a = self.lstm(p_attributes, (hidden_states[0], cell_states[0]))

        # t = -1: feed the projected global visual feature.
        p_gv_feats = self._apply_dropout(self.vfeat_fc(gv_feats))
        h1_v, c1_v = self.lstm(p_gv_feats, (h1_a, c1_a))

        batched_inputs.update({
            kfg.G_HIDDEN_STATES: [h1_v],
            kfg.G_CELL_STATES: [c1_v]
        })
        return batched_inputs

    def forward(self, batched_inputs):
        """Advance the LSTM by one step using the current token embedding.

        Args:
            batched_inputs: Dict-like batch providing ``kfg.G_TOKEN_EMBED``,
                ``kfg.G_HIDDEN_STATES`` and ``kfg.G_CELL_STATES``; only the
                first entry of each state list is used.

        Returns:
            A new dict holding the updated hidden and cell states, each as a
            one-element list. ``batched_inputs`` itself is not modified here.
        """
        wt = batched_inputs[kfg.G_TOKEN_EMBED]
        hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
        cell_states = batched_inputs[kfg.G_CELL_STATES]

        wt = self._apply_dropout(wt)
        h1_t, c1_t = self.lstm(wt, (hidden_states[0], cell_states[0]))

        return {
            kfg.G_HIDDEN_STATES: [h1_t],
            kfg.G_CELL_STATES: [c1_t]
        }