diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-08 22:00:17 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-08 22:00:17 +0100 |
| commit | d1136102c218e9d478c764a32f9672c28f56077d (patch) | |
| tree | cf4e3814397199ad5b94e7be3fd1a1c2bbcf49fc /data/csv.py | |
| parent | Cache token IDs in dataset (diff) | |
| download | textual-inversion-diff-d1136102c218e9d478c764a32f9672c28f56077d.tar.gz textual-inversion-diff-d1136102c218e9d478c764a32f9672c28f56077d.tar.bz2 textual-inversion-diff-d1136102c218e9d478c764a32f9672c28f56077d.zip | |
No cache after all
Diffstat (limited to 'data/csv.py')
| -rw-r--r-- | data/csv.py | 24 |
1 files changed, 7 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py index 58c833e..2f0a392 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -300,22 +300,6 @@ class VlpnDataset(IterableDataset): | |||
| 300 | self.cache = {} | 300 | self.cache = {} |
| 301 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 301 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
| 302 | 302 | ||
| 303 | def get_static_example(self, item: VlpnDataItem, item_index: int): | ||
| 304 | if item_index in self.cache: | ||
| 305 | return copy.copy(self.cache[item_index]) | ||
| 306 | |||
| 307 | example = {} | ||
| 308 | |||
| 309 | example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) | ||
| 310 | example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) | ||
| 311 | |||
| 312 | if self.num_class_images != 0: | ||
| 313 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) | ||
| 314 | |||
| 315 | self.cache[item_index] = example | ||
| 316 | |||
| 317 | return example | ||
| 318 | |||
| 319 | def __len__(self): | 303 | def __len__(self): |
| 320 | return self.length_ | 304 | return self.length_ |
| 321 | 305 | ||
| @@ -374,12 +358,18 @@ class VlpnDataset(IterableDataset): | |||
| 374 | item = self.items[item_index] | 358 | item = self.items[item_index] |
| 375 | mask[self.bucket_item_range[bucket_mask][0]] = False | 359 | mask[self.bucket_item_range[bucket_mask][0]] = False |
| 376 | 360 | ||
| 377 | example = self.get_static_example(item, item_index) | 361 | example = {} |
| 362 | |||
| 363 | example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) | ||
| 364 | example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) | ||
| 365 | |||
| 378 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 366 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
| 379 | keywords_to_prompt(item.prompt, self.dropout, True) | 367 | keywords_to_prompt(item.prompt, self.dropout, True) |
| 380 | ) | 368 | ) |
| 381 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 369 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
| 370 | |||
| 382 | if self.num_class_images != 0: | 371 | if self.num_class_images != 0: |
| 372 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) | ||
| 383 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 373 | example["class_images"] = image_transforms(get_image(item.class_image_path)) |
| 384 | 374 | ||
| 385 | batch.append(example) | 375 | batch.append(example) |
