Source code for xmodaler.modeling.predictor.base_predictor

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import torch
from torch import nn

from xmodaler.config import configurable
from xmodaler.config import kfg
from .build import PREDICTOR_REGISTRY

__all__ = ["BasePredictor"]

@PREDICTOR_REGISTRY.register()
class BasePredictor(nn.Module):
    """Map hidden states to vocabulary logits with a single linear layer.

    The module reads the hidden states stored under ``kfg.G_HIDDEN_STATES``
    in the input dictionary, optionally applies dropout, and projects them
    with ``nn.Linear(hidden_size, vocab_size)``.
    """

    @configurable
    def __init__(
        self,
        *,
        hidden_size: int,
        vocab_size: int,  # include <BOS>/<EOS>
        dropout: float
    ):
        """Build the projection layer and the optional dropout layer.

        Args:
            hidden_size: Number of input features of the linear projection.
            vocab_size: Number of output features of the linear projection
                (the vocabulary size, including <BOS>/<EOS>).
            dropout: Dropout probability. A value ``<= 0.0`` disables
                dropout entirely (``self.dropout`` is then ``None``).
        """
        super().__init__()
        self.logits = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None

    @classmethod
    def from_config(cls, cfg):
        """Translate the config node into ``__init__`` keyword arguments.

        Reads ``cfg.MODEL.DECODER_DIM``, ``cfg.MODEL.VOCAB_SIZE`` and
        ``cfg.MODEL.PRED_DROPOUT``.
        """
        return {
            "hidden_size": cfg.MODEL.DECODER_DIM,
            "vocab_size": cfg.MODEL.VOCAB_SIZE,
            "dropout": cfg.MODEL.PRED_DROPOUT
        }

    @classmethod
    def add_config(cls, cfg):
        """Hook for registering extra config keys; this predictor adds none."""
        pass

    def forward(self, batched_inputs):
        """Compute logits from the hidden states in ``batched_inputs``.

        Args:
            batched_inputs: Mapping that contains the hidden states under
                the key ``kfg.G_HIDDEN_STATES``. The value is either a
                tensor or a list/tuple of tensors; for a sequence only the
                last element is used (presumably the top decoder layer --
                verify against the decoder that produces it).

        Returns:
            A dict with a single entry, ``kfg.G_LOGITS``, holding the output
            of the linear projection.
        """
        hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
        # A sequence of hidden states is reduced to its last entry. Tuples
        # are accepted alongside lists so either container works.
        if isinstance(hidden_states, (list, tuple)):
            hidden_states = hidden_states[-1]

        # Explicit None check: do not rely on nn.Module truthiness.
        if self.dropout is not None:
            hidden_states = self.dropout(hidden_states)

        logits = self.logits(hidden_states)
        return {kfg.G_LOGITS: logits}