Source code for xmodaler.scorer.bert_tokenized_scorer

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo
@contact: jianjieluo.sysu@gmail.com
"""
import numpy as np
import pickle

from xmodaler.config import configurable
from xmodaler.config import kfg
from xmodaler.tokenization import BertTokenizer
from .build import SCORER_REGISTRY

__all__ = ['BertTokenizedScorer']

@SCORER_REGISTRY.register()
class BertTokenizedScorer(object):

    @configurable
    def __init__(
        self,
        *,
        types,
        scorers,
        weights,
        gt_path,
        bert_tokenizer
    ):
        self.types = types
        self.scorers = scorers
        self.weights = weights
        # Ground-truth references, keyed by (stringified) sample id.
        self.gts = pickle.load(open(gt_path, 'rb'), encoding='bytes')
        self.gts = {str(k): v for k, v in self.gts.items()}
        self.tokenizer = bert_tokenizer
        self.sep_token_id = self.tokenizer.vocab["[SEP]"]

    @classmethod
    def from_config(cls, cfg):
        scorers = []
        for name in cfg.SCORER.TYPES:
            scorers.append(SCORER_REGISTRY.get(name)(cfg))

        return {
            'scorers': scorers,
            'types': cfg.SCORER.TYPES,
            'weights': cfg.SCORER.WEIGHTS,
            'gt_path': cfg.SCORER.GT_PATH,
            'bert_tokenizer': BertTokenizer.from_pretrained(
                cfg.MODEL.PRETRAINING.MODEL_NAME,
                do_lower_case=cfg.MODEL.PRETRAINING.DO_LOWER_CASE
            )
        }

    def get_sents(self, sent):
        # Convert a sequence of token ids into whitespace-separated words,
        # stopping at the first [SEP] token (treated as end of sentence).
        words = []
        for word in sent:
            if word == self.sep_token_id:
                words.append('.')
                break
            words.append(self.tokenizer.ids_to_tokens[word])
        words = self.tokenizer.convert_tokens_to_string(words).split()
        return words

    def __call__(self, batched_inputs):
        ids = batched_inputs[kfg.IDS]
        res = batched_inputs[kfg.G_SENTS_IDS]
        res = res.cpu().tolist()

        hypo = [self.get_sents(r) for r in res]
        gts = [self.gts[i] for i in ids]

        rewards_info = {}
        rewards = np.zeros(len(ids))
        for i, scorer in enumerate(self.scorers):
            # Each scorer returns a corpus-level score and per-sample scores.
            score, scores = scorer.compute_score(gts, hypo)
            rewards += self.weights[i] * scores
            rewards_info[self.types[i]] = score

        rewards_info.update({ kfg.REWARDS: rewards })
        return rewards_info
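
# Illustrative usage (a minimal sketch, not part of the original module). It
# assumes `cfg` is a fully populated xmodaler config with SCORER.TYPES,
# SCORER.WEIGHTS, SCORER.GT_PATH and MODEL.PRETRAINING.* set, that
# `batched_inputs` already carries kfg.IDS and kfg.G_SENTS_IDS produced by the
# decoding step, and that the @configurable decorator dispatches a cfg argument
# to from_config (detectron2-style configs).
#
#     scorer = BertTokenizedScorer(cfg)        # constructed via from_config
#     rewards_info = scorer(batched_inputs)    # corpus-level score per metric type
#     rewards = rewards_info[kfg.REWARDS]      # np.ndarray, one reward per sample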