diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py index 0e69d79..24e6091 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -66,6 +66,12 @@ def parse_args(): | |||
66 | help="A token to use as a placeholder for the concept.", | 66 | help="A token to use as a placeholder for the concept.", |
67 | ) | 67 | ) |
68 | parser.add_argument( | 68 | parser.add_argument( |
69 | "--num_class_images", | ||
70 | type=int, | ||
71 | default=2, | ||
72 | help="How many class images to generate per training image." | ||
73 | ) | ||
74 | parser.add_argument( | ||
69 | "--repeats", | 75 | "--repeats", |
70 | type=int, | 76 | type=int, |
71 | default=1, | 77 | default=1, |
@@ -347,6 +353,7 @@ class Checkpointer: | |||
347 | scheduler=scheduler, | 353 | scheduler=scheduler, |
348 | ).to(self.accelerator.device) | 354 | ).to(self.accelerator.device) |
349 | pipeline.enable_attention_slicing() | 355 | pipeline.enable_attention_slicing() |
356 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
350 | 357 | ||
351 | train_data = self.datamodule.train_dataloader() | 358 | train_data = self.datamodule.train_dataloader() |
352 | val_data = self.datamodule.val_dataloader() | 359 | val_data = self.datamodule.val_dataloader() |
@@ -494,7 +501,7 @@ def main(): | |||
494 | pixel_values = [example["instance_images"] for example in examples] | 501 | pixel_values = [example["instance_images"] for example in examples] |
495 | 502 | ||
496 | # concat class and instance examples for prior preservation | 503 | # concat class and instance examples for prior preservation |
497 | if args.class_identifier is not None and "class_prompt_ids" in examples[0]: | 504 | if args.num_class_images != 0 and "class_prompt_ids" in examples[0]: |
498 | input_ids += [example["class_prompt_ids"] for example in examples] | 505 | input_ids += [example["class_prompt_ids"] for example in examples] |
499 | pixel_values += [example["class_images"] for example in examples] | 506 | pixel_values += [example["class_images"] for example in examples] |
500 | 507 | ||
@@ -518,6 +525,7 @@ def main(): | |||
518 | instance_identifier=args.instance_identifier, | 525 | instance_identifier=args.instance_identifier, |
519 | class_identifier=args.class_identifier, | 526 | class_identifier=args.class_identifier, |
520 | class_subdir="db_cls", | 527 | class_subdir="db_cls", |
528 | num_class_images=args.num_class_images, | ||
521 | size=args.resolution, | 529 | size=args.resolution, |
522 | repeats=args.repeats, | 530 | repeats=args.repeats, |
523 | center_crop=args.center_crop, | 531 | center_crop=args.center_crop, |
@@ -528,7 +536,7 @@ def main(): | |||
528 | datamodule.prepare_data() | 536 | datamodule.prepare_data() |
529 | datamodule.setup() | 537 | datamodule.setup() |
530 | 538 | ||
531 | if args.class_identifier is not None: | 539 | if args.num_class_images != 0: |
532 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 540 | missing_data = [item for item in datamodule.data if not item[1].exists()] |
533 | 541 | ||
534 | if len(missing_data) != 0: | 542 | if len(missing_data) != 0: |
@@ -547,6 +555,7 @@ def main(): | |||
547 | scheduler=scheduler, | 555 | scheduler=scheduler, |
548 | ).to(accelerator.device) | 556 | ).to(accelerator.device) |
549 | pipeline.enable_attention_slicing() | 557 | pipeline.enable_attention_slicing() |
558 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
550 | 559 | ||
551 | for batch in batched_data: | 560 | for batch in batched_data: |
552 | image_name = [p[1] for p in batch] | 561 | image_name = [p[1] for p in batch] |
@@ -645,11 +654,18 @@ def main(): | |||
645 | 0, | 654 | 0, |
646 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 655 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
647 | 656 | ||
648 | local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 657 | local_progress_bar = tqdm( |
649 | disable=not accelerator.is_local_main_process) | 658 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
659 | disable=not accelerator.is_local_main_process, | ||
660 | dynamic_ncols=True | ||
661 | ) | ||
650 | local_progress_bar.set_description("Batch X out of Y") | 662 | local_progress_bar.set_description("Batch X out of Y") |
651 | 663 | ||
652 | global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) | 664 | global_progress_bar = tqdm( |
665 | range(args.max_train_steps + val_steps), | ||
666 | disable=not accelerator.is_local_main_process, | ||
667 | dynamic_ncols=True | ||
668 | ) | ||
653 | global_progress_bar.set_description("Total progress") | 669 | global_progress_bar.set_description("Total progress") |
654 | 670 | ||
655 | try: | 671 | try: |
@@ -686,7 +702,7 @@ def main(): | |||
686 | # Predict the noise residual | 702 | # Predict the noise residual |
687 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 703 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
688 | 704 | ||
689 | if args.class_identifier is not None: | 705 | if args.num_class_images != 0: |
690 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 706 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
691 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 707 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
692 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 708 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |