diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 11 |
1 files changed, 1 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py index 6bd7f9b..793fbf8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -150,7 +150,6 @@ class CSVDataset(Dataset): | |||
| 150 | self.class_identifier = class_identifier | 150 | self.class_identifier = class_identifier |
| 151 | self.num_class_images = num_class_images | 151 | self.num_class_images = num_class_images |
| 152 | self.image_cache = {} | 152 | self.image_cache = {} |
| 153 | self.input_id_cache = {} | ||
| 154 | 153 | ||
| 155 | self.num_instance_images = len(self.data) | 154 | self.num_instance_images = len(self.data) |
| 156 | self._length = self.num_instance_images * repeats | 155 | self._length = self.num_instance_images * repeats |
| @@ -185,15 +184,7 @@ class CSVDataset(Dataset): | |||
| 185 | return image | 184 | return image |
| 186 | 185 | ||
| 187 | def get_input_ids(self, prompt, identifier): | 186 | def get_input_ids(self, prompt, identifier): |
| 188 | prompt = prompt.format(identifier) | 187 | return self.prompt_processor.get_input_ids(prompt.format(identifier)) |
| 189 | |||
| 190 | if prompt in self.input_id_cache: | ||
| 191 | return self.input_id_cache[prompt] | ||
| 192 | |||
| 193 | input_ids = self.prompt_processor.get_input_ids(prompt) | ||
| 194 | self.input_id_cache[prompt] = input_ids | ||
| 195 | |||
| 196 | return input_ids | ||
| 197 | 188 | ||
| 198 | def get_example(self, i): | 189 | def get_example(self, i): |
| 199 | item = self.data[i % self.num_instance_images] | 190 | item = self.data[i % self.num_instance_images] |
