diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/data/csv.py b/data/csv.py index d9f9db8..58c833e 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -1,6 +1,7 @@ | |||
1 | import math | 1 | import math |
2 | import torch | 2 | import torch |
3 | import json | 3 | import json |
4 | import copy | ||
4 | from pathlib import Path | 5 | from pathlib import Path |
5 | from typing import NamedTuple, Optional, Union, Callable | 6 | from typing import NamedTuple, Optional, Union, Callable |
6 | 7 | ||
@@ -296,8 +297,25 @@ class VlpnDataset(IterableDataset): | |||
296 | 297 | ||
297 | self.bucket_item_range = torch.arange(len(self.bucket_items)) | 298 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
298 | 299 | ||
300 | self.cache = {} | ||
299 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 301 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
300 | 302 | ||
303 | def get_static_example(self, item: VlpnDataItem, item_index: int): | ||
304 | if item_index in self.cache: | ||
305 | return copy.copy(self.cache[item_index]) | ||
306 | |||
307 | example = {} | ||
308 | |||
309 | example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) | ||
310 | example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) | ||
311 | |||
312 | if self.num_class_images != 0: | ||
313 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) | ||
314 | |||
315 | self.cache[item_index] = example | ||
316 | |||
317 | return example | ||
318 | |||
301 | def __len__(self): | 319 | def __len__(self): |
302 | return self.length_ | 320 | return self.length_ |
303 | 321 | ||
@@ -356,19 +374,13 @@ class VlpnDataset(IterableDataset): | |||
356 | item = self.items[item_index] | 374 | item = self.items[item_index] |
357 | mask[self.bucket_item_range[bucket_mask][0]] = False | 375 | mask[self.bucket_item_range[bucket_mask][0]] = False |
358 | 376 | ||
359 | example = {} | 377 | example = self.get_static_example(item, item_index) |
360 | |||
361 | example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) | ||
362 | example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) | ||
363 | |||
364 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | ||
365 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 378 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
366 | keywords_to_prompt(item.prompt, self.dropout, True) | 379 | keywords_to_prompt(item.prompt, self.dropout, True) |
367 | ) | 380 | ) |
368 | 381 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | |
369 | if self.num_class_images != 0: | 382 | if self.num_class_images != 0: |
370 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 383 | example["class_images"] = image_transforms(get_image(item.class_image_path)) |
371 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) | ||
372 | 384 | ||
373 | batch.append(example) | 385 | batch.append(example) |
374 | 386 | ||