Source code for xmodaler.datasets.images.flickr30k_single_stream

# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Jianjie Luo
@contact: jianjieluo.sysu@gmail.com
"""
import copy
import sys
import torch
import random
import numpy as np
from collections import defaultdict

from xmodaler.config import configurable
from xmodaler.config import kfg
from xmodaler.functional import dict_as_tensor, flat_list_of_lists, pad_tensor, clip_v_inputs
from .flickr30k import Flickr30kDataset
from ..build import DATASETS_REGISTRY

__all__ = ["Flickr30kDatasetForSingleStream", "Flickr30kDatasetForSingleStreamVal"]


@DATASETS_REGISTRY.register()
class Flickr30kDatasetForSingleStream(Flickr30kDataset):
    @configurable
    def __init__(
        self,
        stage: str,
        anno_folder: str,
        anno_file: str,
        feats_folder: str,
        max_feat_num: int,
        max_seq_len: int,
        use_global_v: bool,
        negative_size: int,
        tokenizer,
        cfg
    ):
        super(Flickr30kDatasetForSingleStream, self).__init__(
            stage,
            anno_folder,
            anno_file,
            feats_folder,
            max_feat_num,
            max_seq_len,
            use_global_v,
            tokenizer
        )
        self.negative_size = negative_size

        # load img_ids for neg sample
        datalist = self.load_data(cfg)
        self.imgid2caps = defaultdict(set)
        self.cap2imgids = defaultdict(set)
        for item in datalist:
            image_id = item['image_id']
            caption = tuple(item['captions'])  # NOTE: actually it is one caption
            self.imgid2caps[image_id].add(caption)
            self.cap2imgids[caption].add(image_id)
        self.imgid2caps = {k: list(v) for k, v in dict(self.imgid2caps).items()}
        self.cap2imgids = {k: list(v) for k, v in dict(self.cap2imgids).items()}
        self.image_ids_set = set(list(self.imgid2caps.keys()))
    @classmethod
    def from_config(cls, cfg, stage: str = "train"):
        ret = super().from_config(cfg, stage)
        ret['negative_size'] = cfg.DATALOADER.NEGATIVE_SIZE
        ret['cfg'] = cfg
        return ret
    def sample_neg_pairs(self, pos_meta_data, neg_meta_data_list):
        (image_id, features1, image_locations1, caption1, u_tokens_type1) = pos_meta_data
        neg_features_list, neg_image_locations_list, neg_caption_list, neg_u_tokens_type_list = neg_meta_data_list

        black_img_ids = flat_list_of_lists([self.cap2imgids[c] for c in self.imgid2caps[image_id]])
        image_id_pool = list(self.image_ids_set - set(black_img_ids))

        # sample a cap wrong
        img_id2 = random.choice(image_id_pool)
        features2, image_locations2 = features1, image_locations1
        caption2 = random.choice(self.imgid2caps[img_id2])
        caption2, u_tokens_type2 = self.format_cap(caption2)

        # sample an img wrong
        img_id3 = random.choice(image_id_pool)
        features3, image_locations3 = self.load_img_feat(img_id3)
        caption3, u_tokens_type3 = caption1, u_tokens_type1

        # add neg sample
        neg_features_list.extend([features2, features3])
        neg_image_locations_list.extend([image_locations2, image_locations3])
        neg_caption_list.extend([caption2, caption3])
        neg_u_tokens_type_list.extend([u_tokens_type2, u_tokens_type3])
    def __call__(self, dataset_dict):
        assert self.stage == 'train'
        dataset_dict = copy.deepcopy(dataset_dict)
        image_id = dataset_dict['image_id']

        # Positive
        features1, image_locations1 = self.load_img_feat(image_id)
        caption1 = dataset_dict['captions']
        caption1, u_tokens_type1 = self.format_cap(caption1)

        pos_meta_data = (image_id, features1, image_locations1, caption1, u_tokens_type1)
        neg_features_list = []
        neg_image_locations_list = []
        neg_caption_list = []
        neg_u_tokens_type_list = []
        neg_meta_data_list = [neg_features_list, neg_image_locations_list, neg_caption_list, neg_u_tokens_type_list]

        for _ in range(self.negative_size):
            # negative samples: index 0 is the correct pair; each round adds one
            # "random caption wrong" and one "random image wrong" sample,
            # so self.negative_size rounds yield 2 * self.negative_size negatives.
            self.sample_neg_pairs(pos_meta_data, neg_meta_data_list)

        features = [features1] + neg_features_list
        image_locations = [image_locations1] + neg_image_locations_list
        captions = [caption1] + neg_caption_list
        u_tokens_type = [u_tokens_type1] + neg_u_tokens_type_list

        ret = {
            kfg.ATT_FEATS: [x.astype('float32') for x in features],
            kfg.ATT_FEATS_LOC: [x.astype('float32') for x in image_locations],
            kfg.U_TOKENS_IDS: captions,
            kfg.U_TOKENS_TYPE: u_tokens_type,
            kfg.U_TARGET_IDS: np.array([0], dtype=np.int64).reshape(-1, 1)
        }
        dict_as_tensor(ret)
        ret[kfg.SAMPLE_PER_SAMPLE] = len(features)
        return ret
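
# --- Illustration (not part of the original module) ---------------------------
# A minimal, standalone sketch of the negative-sampling scheme used above, with
# made-up image ids and captions. Each round pairs the positive (image, caption)
# with one "wrong caption" negative (same image, a caption taken from another
# image) and one "wrong image" negative (same caption, a different image), so
# negative_size rounds produce 2 * negative_size negatives. The real dataset
# additionally blacklists every image that shares a caption with the positive.
def _demo_negative_sampling(negative_size=2, seed=0):
    rng = random.Random(seed)
    # hypothetical toy pools; the real class builds these from Flickr30k annotations
    imgid2caps = {
        'img_a': [('a dog',), ('a brown dog',)],
        'img_b': [('a cat',)],
        'img_c': [('a red car',)],
    }
    pos_img, pos_cap = 'img_a', ('a dog',)
    pool = [i for i in imgid2caps if i != pos_img]

    pairs = [(pos_img, pos_cap, 1)]  # label 1 marks the matched pair
    for _ in range(negative_size):
        wrong_cap_img = rng.choice(pool)
        pairs.append((pos_img, rng.choice(imgid2caps[wrong_cap_img]), 0))  # caption wrong
        wrong_img = rng.choice(pool)
        pairs.append((wrong_img, pos_cap, 0))                              # image wrong
    return pairs  # one positive followed by 2 * negative_size negatives
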
@DATASETS_REGISTRY.register()
class Flickr30kDatasetForSingleStreamVal(Flickr30kDataset):
    @configurable
    def __init__(
        self,
        stage: str,
        anno_folder: str,
        anno_file: str,
        feats_folder: str,
        max_feat_num: int,
        max_seq_len: int,
        use_global_v: bool,
        inf_batch_size: int,
        tokenizer,
        cfg
    ):
        super(Flickr30kDatasetForSingleStreamVal, self).__init__(
            stage,
            anno_folder,
            anno_file,
            feats_folder,
            max_feat_num,
            max_seq_len,
            use_global_v,
            tokenizer
        )
        self.inf_batch_size = inf_batch_size

        # load img_ids and captions for the retrieval pool
        datalist = super().load_data(cfg)
        self.imgid2caps = {item['image_id']: item['captions'] for item in datalist}
        self.all_img_ids = list(self.imgid2caps.keys())
        self.imgid2featidx = {i: j for j, i in enumerate(self.all_img_ids)}

        tid2imgid = {}
        imgid2tids = defaultdict(list)
        caption_all = []
        tid = 0
        for image_id, captions in self.imgid2caps.items():
            sent_num = len(captions)
            for i, caption in enumerate(captions):
                curr_tid = tid + i
                caption_all.append(caption)
                tid2imgid[curr_tid] = image_id
                imgid2tids[image_id].append(curr_tid)
            tid += sent_num
        self.tid2imgid = tid2imgid
        self.imgid2tids = dict(imgid2tids)
        self.caption_all = caption_all

        # load v_feature pool
        features_all = []
        image_locations_all = []
        for image_id, feat_idx in self.imgid2featidx.items():
            features, image_locations = self.load_img_feat(image_id)
            features_all.append(torch.as_tensor(features).float())
            image_locations_all.append(torch.as_tensor(image_locations).float())
            sys.stdout.write('%d/%d\r' % (feat_idx, len(self.all_img_ids)))
            sys.stdout.flush()

        vfeats_all, vmasks_all = pad_tensor(features_all, padding_value=0, use_mask=True)
        img_loc_all = pad_tensor(image_locations_all, padding_value=0, use_mask=False)
        self.features_all = vfeats_all.float()
        self.image_mask_all = vmasks_all.float()
        self.image_locations_all = img_loc_all.float()
    @classmethod
    def from_config(cls, cfg, stage: str = "train"):
        ret = super().from_config(cfg, stage)
        ret['inf_batch_size'] = cfg.DATALOADER.INF_BATCH_SIZE
        ret['cfg'] = cfg
        return ret
    def load_data(self, cfg):
        # sample by text
        datalist = []
        for tid, caption in enumerate(self.caption_all):
            datalist.append({
                'tid': tid,
                'caption': caption,
                'tid2imgid': self.tid2imgid[tid],
                'imgid2tids': tuple(self.imgid2tids[self.tid2imgid[tid]]),
                'total_img_num': len(self.all_img_ids)
            })
        # datalist = datalist[:20]  # for debug
        return datalist
    def __call__(self, dataset_dict):
        assert self.stage != 'train'
        dataset_dict = copy.deepcopy(dataset_dict)

        # NOTE: Only support text->image retrieval now
        tid = dataset_dict['tid']
        matched_imgid = dataset_dict['tid2imgid']
        matched_imgfeatidx = self.imgid2featidx[matched_imgid]
        caption = dataset_dict['caption']
        total_img_num = dataset_dict['total_img_num']

        # prepare txt pool
        caption, u_tokens_type = self.format_cap(caption)
        tokens_masks = np.array([1] * len(caption), dtype=np.int64)
        u_tokens_ids = torch.tensor(caption).long()
        u_tokens_type = torch.tensor(u_tokens_type).long()
        tokens_masks = torch.tensor(tokens_masks).long()

        u_tokens_ids_pool = u_tokens_ids.unsqueeze(0).expand(total_img_num, -1)
        u_tokens_type_pool = u_tokens_type.unsqueeze(0).expand(total_img_num, -1)
        tokens_masks_pool = tokens_masks.unsqueeze(0).expand(total_img_num, -1)

        # prepare img pool
        img_feats_pool, image_locations_pool, image_mask_pool = \
            self.features_all.clone(), self.image_locations_all.clone(), self.image_mask_all.clone()

        # chunk to minibatch
        u_tokens_ids_pool = torch.split(u_tokens_ids_pool, self.inf_batch_size, dim=0)
        u_tokens_type_pool = torch.split(u_tokens_type_pool, self.inf_batch_size, dim=0)
        tokens_masks_pool = torch.split(tokens_masks_pool, self.inf_batch_size, dim=0)
        img_feats_pool = torch.split(img_feats_pool, self.inf_batch_size, dim=0)
        image_locations_pool = torch.split(image_locations_pool, self.inf_batch_size, dim=0)
        image_mask_pool = torch.split(image_mask_pool, self.inf_batch_size, dim=0)

        # preprocess in the dataset
        batches = []
        for u_tokens_ids, u_tokens_type, tokens_masks, img_feats, image_locations, image_mask in \
                zip(u_tokens_ids_pool, u_tokens_type_pool, tokens_masks_pool,
                    img_feats_pool, image_locations_pool, image_mask_pool):
            img_feats, image_locations, image_mask = clip_v_inputs(img_feats, image_locations, image_mask)
            batch = {
                kfg.ATT_FEATS: img_feats,
                kfg.ATT_FEATS_LOC: image_locations,
                kfg.ATT_MASKS: image_mask,
                kfg.U_TOKENS_IDS: u_tokens_ids,
                kfg.TOKENS_MASKS: tokens_masks,
                kfg.U_TOKENS_TYPE: u_tokens_type,
            }
            batch['matched_imgfeatidx'] = matched_imgfeatidx
            batch['total_img_num'] = total_img_num

            dict_as_tensor(batch)
            batches.append(batch)

        return batches
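
# --- Illustration (not part of the original module) ---------------------------
# A minimal sketch of how the chunked batches returned by
# Flickr30kDatasetForSingleStreamVal.__call__ could be consumed for text->image
# retrieval. `score_fn` is a hypothetical stand-in for the single-stream
# matching model: it maps one minibatch to a (batch_size,) tensor of
# image-text matching scores for that caption against each image in the chunk.
def _demo_rank_images(batches, score_fn):
    # concatenate per-chunk scores back into one (total_img_num,) vector
    scores = torch.cat([score_fn(b) for b in batches], dim=0)
    ranking = torch.argsort(scores, descending=True)
    matched = int(batches[0]['matched_imgfeatidx'])
    # 0-based rank of the ground-truth image for this caption
    rank = (ranking == matched).nonzero(as_tuple=True)[0].item()
    return rank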