# Source code for xmodaler.modeling.decode_strategy.decode_strategy

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li, Jianjie Luo
@contact: yehaoli.sysu@gmail.com, jianjieluo.sysu@gmail.com
"""
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from xmodaler.config import configurable
from xmodaler.config import kfg
from xmodaler.functional import load_vocab, decode_sequence, decode_sequence_bert
from xmodaler.tokenization import BertTokenizer

class DecodeStrategy(nn.Module, metaclass=ABCMeta):
    """Abstract base for decoding strategies.

    Subclasses implement ``_forward``. ``forward`` calls it and, on request,
    converts the generated token ids into text sentences using either a word
    vocabulary (loaded via ``load_vocab``) or a BERT tokenizer.
    """

    @configurable
    def __init__(
        self,
        *,
        vocab_path,
        beam_size,
        max_seq_len,
        bert_tokenizer,
        bos_token_id,
        eos_token_id
    ):
        """Store decoding settings and the resource used to turn ids into text.

        Args:
            vocab_path: value passed to ``load_vocab``. It is only used when
                ``bert_tokenizer`` is None.
            beam_size: stored as ``self.beam_size`` for subclasses.
            max_seq_len: stored as ``self.max_seq_len`` for subclasses.
            bert_tokenizer: tokenizer handed to ``decode_sequence_bert``, or
                None to decode with the vocabulary loaded from ``vocab_path``.
            bos_token_id: stored as ``self.bos_token_id``.
            eos_token_id: stored as ``self.eos_token_id``; also passed to
                ``decode_sequence_bert`` when decoding with the tokenizer.
        """
        super().__init__()
        self.beam_size = beam_size
        # Exactly one decoding resource is kept. Use a word vocabulary when no
        # tokenizer is supplied; otherwise ``self.vocab`` is None.
        # ``forward`` relies on this None sentinel.
        if bert_tokenizer is None:
            self.vocab = load_vocab(vocab_path)
        else:
            self.vocab = None
        self.max_seq_len = max_seq_len
        self.bert_tokenizer = bert_tokenizer
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    @classmethod
    def from_config(cls, cfg):
        """Build the keyword arguments for ``__init__`` from a config node.

        ``cfg.INFERENCE.VOCAB`` serves two roles. If it names a known
        tokenizer (currently only ``'BERT'``), that tokenizer is loaded from
        ``cfg.MODEL.PRETRAINING.MODEL_NAME``. The BOS and EOS ids become the
        ids of ``[CLS]`` and ``[SEP]``. Otherwise it is treated as the
        vocabulary path, no tokenizer is used, and both ids are 0.

        Returns:
            dict with keys ``vocab_path``, ``beam_size``, ``max_seq_len``,
            ``bert_tokenizer``, ``bos_token_id`` and ``eos_token_id``.
        """
        tokenizer_map = {
            'BERT': BertTokenizer
        }

        tokenizer_cls = tokenizer_map.get(cfg.INFERENCE.VOCAB, None)
        if tokenizer_cls is None:
            bert_tokenizer = None
            bos_token_id = 0
            eos_token_id = 0
        else:
            bert_tokenizer = tokenizer_cls.from_pretrained(
                cfg.MODEL.PRETRAINING.MODEL_NAME,
                do_lower_case=cfg.MODEL.PRETRAINING.DO_LOWER_CASE
            )
            # Always assign both ids in this branch. The old nested
            # ``== 'BERT'`` check would leave them unbound (UnboundLocalError)
            # if another tokenizer were added to ``tokenizer_map``.
            bos_token_id = bert_tokenizer.vocab["[CLS]"]
            eos_token_id = bert_tokenizer.vocab["[SEP]"]

        return {
            "vocab_path": cfg.INFERENCE.VOCAB,
            "beam_size": cfg.DECODE_STRATEGY.BEAM_SIZE,
            "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
            'bert_tokenizer': bert_tokenizer,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id
        }

    @abstractmethod
    def _forward(self, batched_inputs, model):
        """Run the decoding algorithm; implemented by subclasses.

        Must return a mutable mapping (``forward`` calls ``.update`` on it).
        When sentences are requested, the mapping must hold the generated
        token ids under ``kfg.G_SENTS_IDS``. NOTE(review): the id layout is
        presumably (batch, seq_len); confirm against
        ``decode_sequence`` / ``decode_sequence_bert``.
        """
        pass

    def forward(self, batched_inputs, output_sents, model):
        """Decode with ``_forward`` and optionally attach text sentences.

        Args:
            batched_inputs: passed unchanged to ``_forward``.
            output_sents: if truthy, decode ``ret[kfg.G_SENTS_IDS]`` into
                sentences and store them under ``kfg.OUTPUT``.
            model: passed unchanged to ``_forward``.

        Returns:
            The mapping returned by ``_forward``, updated in place with
            ``kfg.OUTPUT`` when ``output_sents`` is truthy.
        """
        ret = self._forward(batched_inputs, model)
        if output_sents:
            # ``__init__`` sets ``self.vocab`` to None exactly when a
            # tokenizer is in use. Test for None explicitly rather than by
            # truthiness: a falsy but loaded vocabulary must not be routed to
            # the tokenizer branch, where ``self.bert_tokenizer`` is None.
            if self.vocab is not None:
                sents = decode_sequence(self.vocab, ret[kfg.G_SENTS_IDS])
            else:
                sents = decode_sequence_bert(
                    self.bert_tokenizer, ret[kfg.G_SENTS_IDS], self.eos_token_id
                )
            ret.update({kfg.OUTPUT: sents})
        return ret