From 0930dae055d9f5cbedcd93c6ddef365538fe69e0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 20:50:28 +0100 Subject: Cache token IDs in dataset --- data/csv.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index d9f9db8..58c833e 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,6 +1,7 @@ import math import torch import json +import copy from pathlib import Path from typing import NamedTuple, Optional, Union, Callable @@ -296,8 +297,25 @@ class VlpnDataset(IterableDataset): self.bucket_item_range = torch.arange(len(self.bucket_items)) + self.cache = {} self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() + def get_static_example(self, item: VlpnDataItem, item_index: int): + if item_index in self.cache: + return copy.copy(self.cache[item_index]) + + example = {} + + example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) + example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) + + if self.num_class_images != 0: + example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) + + self.cache[item_index] = example + + return example + def __len__(self): return self.length_ @@ -356,19 +374,13 @@ class VlpnDataset(IterableDataset): item = self.items[item_index] mask[self.bucket_item_range[bucket_mask][0]] = False - example = {} - - example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) - example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) - - example["instance_images"] = image_transforms(get_image(item.instance_image_path)) + example = self.get_static_example(item, item_index) example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( keywords_to_prompt(item.prompt, self.dropout, True) ) - + example["instance_images"] = image_transforms(get_image(item.instance_image_path)) if self.num_class_images != 0: example["class_images"] = image_transforms(get_image(item.class_image_path)) - example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) batch.append(example) -- cgit v1.2.3-70-g09d2