Source code for xmodaler.datasets.images.flickr30k

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: yehaoli.sysu@gmail.com
"""
import os
import copy
import pickle
import jsonlines
import numpy as np
from xmodaler.config import configurable
from xmodaler.config import kfg
from xmodaler.functional import dict_as_tensor, read_np_bbox
from xmodaler.tokenization import BertTokenizer
from ..build import DATASETS_REGISTRY

__all__ = ["Flickr30kDataset"]

[docs]@DATASETS_REGISTRY.register() class Flickr30kDataset:
[docs] @configurable def __init__( self, stage: str, anno_folder: str, anno_file: str, feats_folder: str, max_feat_num: int, max_seq_len: int, use_global_v: bool, tokenizer ): self.stage = stage self.anno_folder = anno_folder self.anno_file = anno_file self.feats_folder = feats_folder self.max_feat_num = max_feat_num self.max_seq_len = max_seq_len self.use_global_v = use_global_v self.tokenizer = tokenizer
[docs] @classmethod def from_config(cls, cfg, stage: str = "train"): ann_files = { "train": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "all_data_final_train_2014.jsonline"), "val": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "all_data_final_val_set0_2014.jsonline"), "test": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "all_data_final_test_set0_2014.jsonline") } ret = { "stage": stage, "anno_folder": cfg.DATALOADER.ANNO_FOLDER, "anno_file": ann_files[stage], "feats_folder": cfg.DATALOADER.FEATS_FOLDER, "max_feat_num": cfg.DATALOADER.MAX_FEAT_NUM, "max_seq_len": cfg.MODEL.MAX_SEQ_LEN, "use_global_v": cfg.DATALOADER.USE_GLOBAL_V, "tokenizer": BertTokenizer.from_pretrained(cfg.MODEL.PRETRAINING.MODEL_NAME, do_lower_case=cfg.MODEL.PRETRAINING.DO_LOWER_CASE) } return ret
def load_raw_data(self, cfg): datalist = [] with jsonlines.open(self.anno_file) as reader: for annotation in reader: sentences = annotation["sentences"] image_id = annotation["img_path"].split(".")[0] if self.stage == "train": for sent in sentences: datalist.append({ "image_id": image_id, "captions": sent }) else: datalist.append({ "image_id": image_id, "captions": sentences }) return datalist
[docs] def load_data(self, cfg): cache_path = os.path.join( self.anno_folder, "cache", "RetrievalFlickr30k_%s_%d.pkl" % (self.stage, self.max_seq_len) ) if not os.path.exists(cache_path): datalist = self.load_raw_data(cfg) self.tokenize(datalist) pickle.dump(datalist, open(cache_path, "wb")) datalist = pickle.load(open(cache_path, "rb")) return datalist
def tokenize(self, datalist): for entry in datalist: captions = entry["captions"] if isinstance(captions, list): tokens_arr = [] for caption in captions: tokens = self.tokenizer.encode(caption) tokens = tokens[: self.max_seq_len - 2] tokens = self.tokenizer.add_special_tokens_single_sentence(tokens) tokens_arr.append(tokens) entry["captions"] = tokens_arr else: tokens = self.tokenizer.encode(captions) tokens = tokens[: self.max_seq_len - 2] tokens = self.tokenizer.add_special_tokens_single_sentence(tokens) entry["captions"] = tokens def load_img_feat(self, image_id): image_path = os.path.join(self.feats_folder, image_id + ".npz") features, image_locations = read_np_bbox(image_path, self.max_feat_num, self.use_global_v) return features, image_locations def format_cap(self, caption): u_tokens_type = np.array([0] * len(caption)).astype(np.int64) caption = np.array(caption).astype(np.int64) return caption, u_tokens_type
[docs] def __call__(self, dataset_dict): dataset_dict = copy.deepcopy(dataset_dict) image_id = dataset_dict['image_id'] image_path = os.path.join(self.feats_folder, image_id + ".npz") features, image_locations = read_np_bbox(image_path, self.max_feat_num, self.use_global_v) captions = dataset_dict['captions'] if self.stage == "train": u_tokens_type = np.array([0] * len(captions)).astype(np.int64) captions = np.array(captions).astype(np.int64) ids = image_id else: ids = [image_id, [image_id] * len(captions)] u_tokens_type = [ np.array([0] * len(caption)).astype(np.int64) for caption in captions ] captions = [np.array(caption).astype(np.int64) for caption in captions] ret = { kfg.ATT_FEATS: features.astype('float32'), kfg.ATT_FEATS_LOC: image_locations.astype('float32'), kfg.U_TOKENS_IDS: captions, kfg.U_TOKENS_TYPE: u_tokens_type, } dict_as_tensor(ret) ret.update({ kfg.IDS: ids }) return ret