| author | Volpeon <git@volpeon.ink> | 2022-10-07 08:34:07 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-07 08:34:07 +0200 |
| commit | 14beba63391e1ddc9a145bb638d9306086ad1a5c | |
| tree | 4e7d5126359c4ab6ab6dff3c2af537d659e276e8 /dreambooth.py | |
| parent | Update | |
Training: Create multiple class images per training image
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 28 |
1 file changed, 22 insertions(+), 6 deletions(-)
    diff --git a/dreambooth.py b/dreambooth.py
    index 0e69d79..24e6091 100644
    --- a/dreambooth.py
    +++ b/dreambooth.py
    @@ -66,6 +66,12 @@ def parse_args():
             help="A token to use as a placeholder for the concept.",
         )
         parser.add_argument(
    +        "--num_class_images",
    +        type=int,
    +        default=2,
    +        help="How many class images to generate per training image."
    +    )
    +    parser.add_argument(
             "--repeats",
             type=int,
             default=1,
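The commit subject ("Create multiple class images per training image") together with the new flag's help text implies the data module now reserves `num_class_images` class-image slots for every training image. The `CSVDataModule` internals are not part of this diff, so the sketch below only illustrates that pairing; the helper name, file naming scheme, and `class_dir` argument are all assumptions.

```python
from pathlib import Path

def class_image_paths(instance_images: list, class_dir: Path, num_class_images: int) -> list:
    # Hypothetical helper (not in this diff): one class-image slot per
    # (training image, index) pair. Paths that don't exist yet are generated
    # later -- see the hunk that checks item[1].exists() further down.
    return [
        class_dir / f"{idx}_{n}.jpg"
        for idx, _ in enumerate(instance_images)
        for n in range(num_class_images)
    ]
```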
    @@ -347,6 +353,7 @@ class Checkpointer:
                 scheduler=scheduler,
             ).to(self.accelerator.device)
             pipeline.enable_attention_slicing()
    +        pipeline.set_progress_bar_config(dynamic_ncols=True)
    
             train_data = self.datamodule.train_dataloader()
             val_data = self.datamodule.val_dataloader()
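`set_progress_bar_config()` is the diffusers pipeline hook that stores keyword arguments and forwards them to the `tqdm` bar shown during sampling, so this line has the same effect as passing `dynamic_ncols=True` to `tqdm` directly: the bar re-measures the terminal width on each refresh instead of wrapping onto new lines. A minimal stand-alone illustration of that tqdm option:

```python
import time
from tqdm.auto import tqdm

# dynamic_ncols=True makes the bar resize with the terminal
# instead of overflowing when the window narrows.
for _ in tqdm(range(100), dynamic_ncols=True):
    time.sleep(0.01)
```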
    @@ -494,7 +501,7 @@ def main():
             pixel_values = [example["instance_images"] for example in examples]
    
             # concat class and instance examples for prior preservation
    -        if args.class_identifier is not None and "class_prompt_ids" in examples[0]:
    +        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
                 input_ids += [example["class_prompt_ids"] for example in examples]
                 pixel_values += [example["class_images"] for example in examples]
    
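The switch from `args.class_identifier` to `args.num_class_images` changes which condition enables prior preservation, but the batch layout stays the same: class examples are appended after the instance examples. A minimal sketch of that layout (the `instance_prompt_ids` key and the padding/stacking step are assumptions; only the class-side lines appear in this hunk):

```python
def collate_sketch(examples, with_prior_preservation: bool):
    # Layout sketch only -- the real collate_fn also pads input_ids with the
    # tokenizer and stacks pixel_values into a tensor.
    input_ids = [e["instance_prompt_ids"] for e in examples]  # assumed key
    pixel_values = [e["instance_images"] for e in examples]

    if with_prior_preservation:
        # Class examples go *after* the instance examples, so the first half
        # of each batch is instance data and the second half is class data.
        # That ordering is what lets the loss code split the batch back apart
        # with torch.chunk(..., 2, dim=0) in the last hunk below.
        input_ids += [e["class_prompt_ids"] for e in examples]
        pixel_values += [e["class_images"] for e in examples]

    return input_ids, pixel_values
```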
    @@ -518,6 +525,7 @@ def main():
             instance_identifier=args.instance_identifier,
             class_identifier=args.class_identifier,
             class_subdir="db_cls",
    +        num_class_images=args.num_class_images,
             size=args.resolution,
             repeats=args.repeats,
             center_crop=args.center_crop,
    @@ -528,7 +536,7 @@ def main():
         datamodule.prepare_data()
         datamodule.setup()
    
    -    if args.class_identifier is not None:
    +    if args.num_class_images != 0:
             missing_data = [item for item in datamodule.data if not item[1].exists()]
    
             if len(missing_data) != 0:
    @@ -547,6 +555,7 @@ def main():
                 scheduler=scheduler,
             ).to(accelerator.device)
             pipeline.enable_attention_slicing()
    +        pipeline.set_progress_bar_config(dynamic_ncols=True)
    
             for batch in batched_data:
                 image_name = [p[1] for p in batch]
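Only the first line of the generation loop is visible in this hunk. A hedged sketch of how such a loop typically continues, assuming each item in `batched_data` is a (prompt, path) pair, which matches the `item[1].exists()` check above; the project's own pipeline was called earlier with positional resolution/guidance/step arguments, so the stock diffusers-style call used here is an illustration, not the repository's actual signature:

```python
import torch

def generate_missing_class_images(pipeline, batched_data, sample_steps: int = 50):
    # Hypothetical continuation of the loop shown above (not part of the diff).
    for batch in batched_data:
        image_name = [p[1] for p in batch]  # target paths, per the exists() check
        prompt = [p[0] for p in batch]      # assumed: prompt stored at index 0

        with torch.inference_mode():
            images = pipeline(prompt=prompt, num_inference_steps=sample_steps).images

        for path, image in zip(image_name, images):
            image.save(path)
```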
    @@ -645,11 +654,18 @@ def main():
                 0,
                 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
    
    -    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
    -                              disable=not accelerator.is_local_main_process)
    +    local_progress_bar = tqdm(
    +        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
    +        disable=not accelerator.is_local_main_process,
    +        dynamic_ncols=True
    +    )
         local_progress_bar.set_description("Batch X out of Y")
    
    -    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
    +    global_progress_bar = tqdm(
    +        range(args.max_train_steps + val_steps),
    +        disable=not accelerator.is_local_main_process,
    +        dynamic_ncols=True
    +    )
         global_progress_bar.set_description("Total progress")
    
         try:
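The training loop keeps two bars: a local one sized to one epoch of training plus validation update steps, and a global one sized to the whole run; both now get `dynamic_ncols=True`. A small self-contained illustration of the same two-bar pattern (the epoch handling here is invented for the example; the real script manages the bars elsewhere in the file):

```python
from tqdm.auto import tqdm

epochs, steps_per_epoch = 3, 10

global_bar = tqdm(range(epochs * steps_per_epoch), dynamic_ncols=True)
global_bar.set_description("Total progress")

for epoch in range(epochs):
    local_bar = tqdm(range(steps_per_epoch), dynamic_ncols=True, leave=False)
    local_bar.set_description(f"Batch {epoch + 1} out of {epochs}")
    for _ in range(steps_per_epoch):
        local_bar.update(1)
        global_bar.update(1)
    local_bar.close()

global_bar.close()
```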
    @@ -686,7 +702,7 @@ def main():
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    
    -                if args.class_identifier is not None:
    +                if args.num_class_images != 0:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                         noise, noise_prior = torch.chunk(noise, 2, dim=0)
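The lines that follow this hunk are not shown, but with the batch laid out as instance examples first and class examples second, the chunked tensors feed the standard DreamBooth prior-preservation loss. A minimal sketch of that computation; `prior_loss_weight` does not appear in this diff and is assumed here:

```python
import torch
import torch.nn.functional as F

def prior_preservation_loss(noise_pred: torch.Tensor, noise: torch.Tensor,
                            prior_loss_weight: float = 1.0) -> torch.Tensor:
    # First half of the batch: instance examples; second half: class examples
    # (matches the collate ordering sketched earlier).
    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    # Reconstruction loss on the instance images.
    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

    # Prior-preservation loss on the generated class images keeps the model
    # close to what it could already produce for the bare class prompt.
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

    return loss + prior_loss_weight * prior_loss
```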
