diff options
| -rw-r--r-- | data/csv.py | 55 |
1 files changed, 26 insertions, 29 deletions
diff --git a/data/csv.py b/data/csv.py index 480e9f2..233f5d8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -399,23 +399,20 @@ class VlpnDataset(IterableDataset): | |||
| 399 | mask[:start] = False | 399 | mask[:start] = False |
| 400 | mask[end:] = False | 400 | mask[end:] = False |
| 401 | 401 | ||
| 402 | while mask.any(): | 402 | while mask.any() or len(batch) != 0: |
| 403 | if len(batch) >= batch_size: | ||
| 404 | yield batch | ||
| 405 | batch = [] | ||
| 406 | |||
| 403 | bucket_mask = mask.logical_and(self.bucket_assignments == bucket) | 407 | bucket_mask = mask.logical_and(self.bucket_assignments == bucket) |
| 404 | bucket_items = self.bucket_items[bucket_mask] | 408 | bucket_items = self.bucket_items[bucket_mask] |
| 405 | 409 | ||
| 406 | if len(batch) >= batch_size: | 410 | if len(bucket_items) == 0 and len(batch) != 0 and not self.fill_batch: |
| 407 | yield batch | 411 | yield batch |
| 408 | batch = [] | 412 | batch = [] |
| 413 | continue | ||
| 409 | 414 | ||
| 410 | if len(bucket_items) == 0: | 415 | if len(bucket_items) == 0 and len(batch) == 0: |
| 411 | if len(batch) != 0: | ||
| 412 | if self.fill_batch: | ||
| 413 | fill_items = self.bucket_items[self.bucket_assignments == bucket] | ||
| 414 | fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) | ||
| 415 | batch += fill_items[fill_perm] | ||
| 416 | yield batch | ||
| 417 | batch = [] | ||
| 418 | |||
| 419 | bucket = self.bucket_assignments[mask][0] | 416 | bucket = self.bucket_assignments[mask][0] |
| 420 | ratio = self.buckets[bucket] | 417 | ratio = self.buckets[bucket] |
| 421 | width = int(self.size * ratio) if ratio > 1 else self.size | 418 | width = int(self.size * ratio) if ratio > 1 else self.size |
| @@ -430,30 +427,30 @@ class VlpnDataset(IterableDataset): | |||
| 430 | transforms.Normalize([0.5], [0.5]), | 427 | transforms.Normalize([0.5], [0.5]), |
| 431 | ] | 428 | ] |
| 432 | ) | 429 | ) |
| 430 | |||
| 431 | continue | ||
| 432 | |||
| 433 | if len(bucket_items) == 0: | ||
| 434 | bucket_items = self.bucket_items[self.bucket_assignments == bucket] | ||
| 435 | item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)] | ||
| 433 | else: | 436 | else: |
| 434 | item_index = bucket_items[0] | 437 | item_index = bucket_items[0] |
| 435 | item = self.items[item_index] | ||
| 436 | mask[self.bucket_item_range[bucket_mask][0]] = False | 438 | mask[self.bucket_item_range[bucket_mask][0]] = False |
| 437 | 439 | ||
| 438 | example = {} | 440 | item = self.items[item_index] |
| 439 | 441 | ||
| 440 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) | 442 | example = {} |
| 441 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | ||
| 442 | 443 | ||
| 443 | example["instance_prompt_ids"] = self.get_input_ids( | 444 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) |
| 444 | keywords_to_prompt(item.prompt, self.dropout, True) | 445 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
| 445 | ) | ||
| 446 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | ||
| 447 | 446 | ||
| 448 | if self.num_class_images != 0: | 447 | example["instance_prompt_ids"] = self.get_input_ids( |
| 449 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) | 448 | keywords_to_prompt(item.prompt, self.dropout, True) |
| 450 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 449 | ) |
| 450 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | ||
| 451 | 451 | ||
| 452 | batch.append(example) | 452 | if self.num_class_images != 0: |
| 453 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) | ||
| 454 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | ||
| 453 | 455 | ||
| 454 | if len(batch) != 0: | 456 | batch.append(example) |
| 455 | if self.fill_batch: | ||
| 456 | fill_items = self.bucket_items[self.bucket_assignments == bucket] | ||
| 457 | fill_perm = torch.randint(len(fill_items), (batch_size - len(batch),), generator=self.generator) | ||
| 458 | batch += fill_items[fill_perm] | ||
| 459 | yield batch | ||
