summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py55
1 files changed, 26 insertions, 29 deletions
diff --git a/data/csv.py b/data/csv.py
index 480e9f2..233f5d8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -399,23 +399,20 @@ class VlpnDataset(IterableDataset):
399 mask[:start] = False 399 mask[:start] = False
400 mask[end:] = False 400 mask[end:] = False
401 401
402 while mask.any(): 402 while mask.any() or len(batch) != 0:
403 if len(batch) >= batch_size:
404 yield batch
405 batch = []
406
403 bucket_mask = mask.logical_and(self.bucket_assignments == bucket) 407 bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
404 bucket_items = self.bucket_items[bucket_mask] 408 bucket_items = self.bucket_items[bucket_mask]
405 409
406 if len(batch) >= batch_size: 410 if len(bucket_items) == 0 and len(batch) != 0 and not self.fill_batch:
407 yield batch 411 yield batch
408 batch = [] 412 batch = []
413 continue
409 414
410 if len(bucket_items) == 0: 415 if len(bucket_items) == 0 and len(batch) == 0:
411 if len(batch) != 0:
412 if self.fill_batch:
413 fill_items = self.bucket_items[self.bucket_assignments == bucket]
414 fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator)
415 batch += fill_items[fill_perm]
416 yield batch
417 batch = []
418
419 bucket = self.bucket_assignments[mask][0] 416 bucket = self.bucket_assignments[mask][0]
420 ratio = self.buckets[bucket] 417 ratio = self.buckets[bucket]
421 width = int(self.size * ratio) if ratio > 1 else self.size 418 width = int(self.size * ratio) if ratio > 1 else self.size
@@ -430,30 +427,30 @@ class VlpnDataset(IterableDataset):
430 transforms.Normalize([0.5], [0.5]), 427 transforms.Normalize([0.5], [0.5]),
431 ] 428 ]
432 ) 429 )
430
431 continue
432
433 if len(bucket_items) == 0:
434 bucket_items = self.bucket_items[self.bucket_assignments == bucket]
435 item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)]
433 else: 436 else:
434 item_index = bucket_items[0] 437 item_index = bucket_items[0]
435 item = self.items[item_index]
436 mask[self.bucket_item_range[bucket_mask][0]] = False 438 mask[self.bucket_item_range[bucket_mask][0]] = False
437 439
438 example = {} 440 item = self.items[item_index]
439 441
440 example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) 442 example = {}
441 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
442 443
443 example["instance_prompt_ids"] = self.get_input_ids( 444 example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
444 keywords_to_prompt(item.prompt, self.dropout, True) 445 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
445 )
446 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
447 446
448 if self.num_class_images != 0: 447 example["instance_prompt_ids"] = self.get_input_ids(
449 example["class_prompt_ids"] = self.get_input_ids(item.cprompt) 448 keywords_to_prompt(item.prompt, self.dropout, True)
450 example["class_images"] = image_transforms(get_image(item.class_image_path)) 449 )
450 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
451 451
452 batch.append(example) 452 if self.num_class_images != 0:
453 example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
454 example["class_images"] = image_transforms(get_image(item.class_image_path))
453 455
454 if len(batch) != 0: 456 batch.append(example)
455 if self.fill_batch:
456 fill_items = self.bucket_items[self.bucket_assignments == bucket]
457 fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator)
458 batch += fill_items[fill_perm]
459 yield batch