summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2023-01-08 20:50:28 +0100
committer Volpeon <git@volpeon.ink> 2023-01-08 20:50:28 +0100
commit 0930dae055d9f5cbedcd93c6ddef365538fe69e0 (patch)
tree 024420abba506626976c8c84b2cadde89dcc4ff7
parent Fix (diff)
download textual-inversion-diff-0930dae055d9f5cbedcd93c6ddef365538fe69e0.tar.gz
textual-inversion-diff-0930dae055d9f5cbedcd93c6ddef365538fe69e0.tar.bz2
textual-inversion-diff-0930dae055d9f5cbedcd93c6ddef365538fe69e0.zip
Cache token IDs in dataset
-rw-r--r-- data/csv.py 28
1 file changed, 20 insertions, 8 deletions
diff --git a/data/csv.py b/data/csv.py
index d9f9db8..58c833e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,6 +1,7 @@
1import math 1import math
2import torch 2import torch
3import json 3import json
4import copy
4from pathlib import Path 5from pathlib import Path
5from typing import NamedTuple, Optional, Union, Callable 6from typing import NamedTuple, Optional, Union, Callable
6 7
@@ -296,8 +297,25 @@ class VlpnDataset(IterableDataset):
296 297
297 self.bucket_item_range = torch.arange(len(self.bucket_items)) 298 self.bucket_item_range = torch.arange(len(self.bucket_items))
298 299
300 self.cache = {}
299 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() 301 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
300 302
def get_static_example(self, item: VlpnDataItem, item_index: int):
    """Return the cached prompt token IDs for ``item`` as a fresh dict.

    The entries that do not change between iterations are computed once
    per ``item_index`` and stored in ``self.cache``:

    * ``"prompt_ids"``: ``get_input_ids(keywords_to_prompt(item.prompt))``
    * ``"nprompt_ids"``: ``get_input_ids(item.nprompt)``
    * ``"class_prompt_ids"``: ``get_input_ids(item.cprompt)``, only when
      ``self.num_class_images != 0``

    Args:
        item: The data item whose prompts are tokenized.
        item_index: Key under which the result is cached.

    Returns:
        A shallow copy of the cached dict. The caller may add or replace
        keys on it without affecting the cache.
    """
    cached = self.cache.get(item_index)

    if cached is None:
        get_input_ids = self.prompt_processor.get_input_ids

        cached = {
            "prompt_ids": get_input_ids(keywords_to_prompt(item.prompt)),
            "nprompt_ids": get_input_ids(item.nprompt),
        }

        if self.num_class_images != 0:
            cached["class_prompt_ids"] = get_input_ids(item.cprompt)

        self.cache[item_index] = cached

    # Copy on the miss path as well as the hit path. The caller writes
    # per-iteration entries (images, instance prompt IDs) into the returned
    # dict; handing out the cached object itself would store those entries
    # in the cache and keep them alive for every item.
    return copy.copy(cached)
318
301 def __len__(self): 319 def __len__(self):
302 return self.length_ 320 return self.length_
303 321
@@ -356,19 +374,13 @@ class VlpnDataset(IterableDataset):
356 item = self.items[item_index] 374 item = self.items[item_index]
357 mask[self.bucket_item_range[bucket_mask][0]] = False 375 mask[self.bucket_item_range[bucket_mask][0]] = False
358 376
359 example = {} 377 example = self.get_static_example(item, item_index)
360
361 example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
362 example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
363
364 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
365 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( 378 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
366 keywords_to_prompt(item.prompt, self.dropout, True) 379 keywords_to_prompt(item.prompt, self.dropout, True)
367 ) 380 )
368 381 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
369 if self.num_class_images != 0: 382 if self.num_class_images != 0:
370 example["class_images"] = image_transforms(get_image(item.class_image_path)) 383 example["class_images"] = image_transforms(get_image(item.class_image_path))
371 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
372 384
373 batch.append(example) 385 batch.append(example)
374 386