diff options
-rw-r--r-- | data/csv.py | 55 |
1 files changed, 26 insertions, 29 deletions
diff --git a/data/csv.py b/data/csv.py index 480e9f2..233f5d8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -399,23 +399,20 @@ class VlpnDataset(IterableDataset): | |||
399 | mask[:start] = False | 399 | mask[:start] = False |
400 | mask[end:] = False | 400 | mask[end:] = False |
401 | 401 | ||
402 | while mask.any(): | 402 | while mask.any() or len(batch) != 0: |
403 | if len(batch) >= batch_size: | ||
404 | yield batch | ||
405 | batch = [] | ||
406 | |||
403 | bucket_mask = mask.logical_and(self.bucket_assignments == bucket) | 407 | bucket_mask = mask.logical_and(self.bucket_assignments == bucket) |
404 | bucket_items = self.bucket_items[bucket_mask] | 408 | bucket_items = self.bucket_items[bucket_mask] |
405 | 409 | ||
406 | if len(batch) >= batch_size: | 410 | if len(bucket_items) == 0 and len(batch) != 0 and not self.fill_batch: |
407 | yield batch | 411 | yield batch |
408 | batch = [] | 412 | batch = [] |
413 | continue | ||
409 | 414 | ||
410 | if len(bucket_items) == 0: | 415 | if len(bucket_items) == 0 and len(batch) == 0: |
411 | if len(batch) != 0: | ||
412 | if self.fill_batch: | ||
413 | fill_items = self.bucket_items[self.bucket_assignments == bucket] | ||
414 | fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) | ||
415 | batch += fill_items[fill_perm] | ||
416 | yield batch | ||
417 | batch = [] | ||
418 | |||
419 | bucket = self.bucket_assignments[mask][0] | 416 | bucket = self.bucket_assignments[mask][0] |
420 | ratio = self.buckets[bucket] | 417 | ratio = self.buckets[bucket] |
421 | width = int(self.size * ratio) if ratio > 1 else self.size | 418 | width = int(self.size * ratio) if ratio > 1 else self.size |
@@ -430,30 +427,30 @@ class VlpnDataset(IterableDataset): | |||
430 | transforms.Normalize([0.5], [0.5]), | 427 | transforms.Normalize([0.5], [0.5]), |
431 | ] | 428 | ] |
432 | ) | 429 | ) |
430 | |||
431 | continue | ||
432 | |||
433 | if len(bucket_items) == 0: | ||
434 | bucket_items = self.bucket_items[self.bucket_assignments == bucket] | ||
435 | item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)] | ||
433 | else: | 436 | else: |
434 | item_index = bucket_items[0] | 437 | item_index = bucket_items[0] |
435 | item = self.items[item_index] | ||
436 | mask[self.bucket_item_range[bucket_mask][0]] = False | 438 | mask[self.bucket_item_range[bucket_mask][0]] = False |
437 | 439 | ||
438 | example = {} | 440 | item = self.items[item_index] |
439 | 441 | ||
440 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) | 442 | example = {} |
441 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | ||
442 | 443 | ||
443 | example["instance_prompt_ids"] = self.get_input_ids( | 444 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) |
444 | keywords_to_prompt(item.prompt, self.dropout, True) | 445 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
445 | ) | ||
446 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | ||
447 | 446 | ||
448 | if self.num_class_images != 0: | 447 | example["instance_prompt_ids"] = self.get_input_ids( |
449 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) | 448 | keywords_to_prompt(item.prompt, self.dropout, True) |
450 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 449 | ) |
450 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | ||
451 | 451 | ||
452 | batch.append(example) | 452 | if self.num_class_images != 0: |
453 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) | ||
454 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | ||
453 | 455 | ||
454 | if len(batch) != 0: | 456 | batch.append(example) |
455 | if self.fill_batch: | ||
456 | fill_items = self.bucket_items[self.bucket_assignments == bucket] | ||
457 | fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) | ||
458 | batch += fill_items[fill_perm] | ||
459 | yield batch | ||