From b2c3389e9c6375d9081625e75a99de98395f8e77 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 1 Nov 2022 16:19:01 +0100 Subject: Update --- data/csv.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 6bd7f9b..793fbf8 100644 --- a/data/csv.py +++ b/data/csv.py @@ -150,7 +150,6 @@ class CSVDataset(Dataset): self.class_identifier = class_identifier self.num_class_images = num_class_images self.image_cache = {} - self.input_id_cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats @@ -185,15 +184,7 @@ class CSVDataset(Dataset): return image def get_input_ids(self, prompt, identifier): - prompt = prompt.format(identifier) - - if prompt in self.input_id_cache: - return self.input_id_cache[prompt] - - input_ids = self.prompt_processor.get_input_ids(prompt) - self.input_id_cache[prompt] = input_ids - - return input_ids + return self.prompt_processor.get_input_ids(prompt.format(identifier)) def get_example(self, i): item = self.data[i % self.num_instance_images] -- cgit v1.2.3-70-g09d2