summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py24
1 files changed, 7 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py
index 58c833e..2f0a392 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -300,22 +300,6 @@ class VlpnDataset(IterableDataset):
300 self.cache = {} 300 self.cache = {}
301 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() 301 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
302 302
303 def get_static_example(self, item: VlpnDataItem, item_index: int):
304 if item_index in self.cache:
305 return copy.copy(self.cache[item_index])
306
307 example = {}
308
309 example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
310 example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
311
312 if self.num_class_images != 0:
313 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
314
315 self.cache[item_index] = example
316
317 return example
318
319 def __len__(self): 303 def __len__(self):
320 return self.length_ 304 return self.length_
321 305
@@ -374,12 +358,18 @@ class VlpnDataset(IterableDataset):
374 item = self.items[item_index] 358 item = self.items[item_index]
375 mask[self.bucket_item_range[bucket_mask][0]] = False 359 mask[self.bucket_item_range[bucket_mask][0]] = False
376 360
377 example = self.get_static_example(item, item_index) 361 example = {}
362
363 example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
364 example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
365
378 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( 366 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
379 keywords_to_prompt(item.prompt, self.dropout, True) 367 keywords_to_prompt(item.prompt, self.dropout, True)
380 ) 368 )
381 example["instance_images"] = image_transforms(get_image(item.instance_image_path)) 369 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
370
382 if self.num_class_images != 0: 371 if self.num_class_images != 0:
372 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
383 example["class_images"] = image_transforms(get_image(item.class_image_path)) 373 example["class_images"] = image_transforms(get_image(item.class_image_path))
384 374
385 batch.append(example) 375 batch.append(example)