summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-07 08:34:07 +0200
committerVolpeon <git@volpeon.ink>2022-10-07 08:34:07 +0200
commit14beba63391e1ddc9a145bb638d9306086ad1a5c (patch)
tree4e7d5126359c4ab6ab6dff3c2af537d659e276e8 /dreambooth.py
parentUpdate (diff)
downloadtextual-inversion-diff-14beba63391e1ddc9a145bb638d9306086ad1a5c.tar.gz
textual-inversion-diff-14beba63391e1ddc9a145bb638d9306086ad1a5c.tar.bz2
textual-inversion-diff-14beba63391e1ddc9a145bb638d9306086ad1a5c.zip
Training: Create multiple class images per training image
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py28
1 file changed, 22 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 0e69d79..24e6091 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -66,6 +66,12 @@ def parse_args():
66 help="A token to use as a placeholder for the concept.", 66 help="A token to use as a placeholder for the concept.",
67 ) 67 )
68 parser.add_argument( 68 parser.add_argument(
69 "--num_class_images",
70 type=int,
71 default=2,
72 help="How many class images to generate per training image."
73 )
74 parser.add_argument(
69 "--repeats", 75 "--repeats",
70 type=int, 76 type=int,
71 default=1, 77 default=1,
@@ -347,6 +353,7 @@ class Checkpointer:
347 scheduler=scheduler, 353 scheduler=scheduler,
348 ).to(self.accelerator.device) 354 ).to(self.accelerator.device)
349 pipeline.enable_attention_slicing() 355 pipeline.enable_attention_slicing()
356 pipeline.set_progress_bar_config(dynamic_ncols=True)
350 357
351 train_data = self.datamodule.train_dataloader() 358 train_data = self.datamodule.train_dataloader()
352 val_data = self.datamodule.val_dataloader() 359 val_data = self.datamodule.val_dataloader()
@@ -494,7 +501,7 @@ def main():
494 pixel_values = [example["instance_images"] for example in examples] 501 pixel_values = [example["instance_images"] for example in examples]
495 502
496 # concat class and instance examples for prior preservation 503 # concat class and instance examples for prior preservation
497 if args.class_identifier is not None and "class_prompt_ids" in examples[0]: 504 if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
498 input_ids += [example["class_prompt_ids"] for example in examples] 505 input_ids += [example["class_prompt_ids"] for example in examples]
499 pixel_values += [example["class_images"] for example in examples] 506 pixel_values += [example["class_images"] for example in examples]
500 507
@@ -518,6 +525,7 @@ def main():
518 instance_identifier=args.instance_identifier, 525 instance_identifier=args.instance_identifier,
519 class_identifier=args.class_identifier, 526 class_identifier=args.class_identifier,
520 class_subdir="db_cls", 527 class_subdir="db_cls",
528 num_class_images=args.num_class_images,
521 size=args.resolution, 529 size=args.resolution,
522 repeats=args.repeats, 530 repeats=args.repeats,
523 center_crop=args.center_crop, 531 center_crop=args.center_crop,
@@ -528,7 +536,7 @@ def main():
528 datamodule.prepare_data() 536 datamodule.prepare_data()
529 datamodule.setup() 537 datamodule.setup()
530 538
531 if args.class_identifier is not None: 539 if args.num_class_images != 0:
532 missing_data = [item for item in datamodule.data if not item[1].exists()] 540 missing_data = [item for item in datamodule.data if not item[1].exists()]
533 541
534 if len(missing_data) != 0: 542 if len(missing_data) != 0:
@@ -547,6 +555,7 @@ def main():
547 scheduler=scheduler, 555 scheduler=scheduler,
548 ).to(accelerator.device) 556 ).to(accelerator.device)
549 pipeline.enable_attention_slicing() 557 pipeline.enable_attention_slicing()
558 pipeline.set_progress_bar_config(dynamic_ncols=True)
550 559
551 for batch in batched_data: 560 for batch in batched_data:
552 image_name = [p[1] for p in batch] 561 image_name = [p[1] for p in batch]
@@ -645,11 +654,18 @@ def main():
645 0, 654 0,
646 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 655 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
647 656
648 local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), 657 local_progress_bar = tqdm(
649 disable=not accelerator.is_local_main_process) 658 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
659 disable=not accelerator.is_local_main_process,
660 dynamic_ncols=True
661 )
650 local_progress_bar.set_description("Batch X out of Y") 662 local_progress_bar.set_description("Batch X out of Y")
651 663
652 global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) 664 global_progress_bar = tqdm(
665 range(args.max_train_steps + val_steps),
666 disable=not accelerator.is_local_main_process,
667 dynamic_ncols=True
668 )
653 global_progress_bar.set_description("Total progress") 669 global_progress_bar.set_description("Total progress")
654 670
655 try: 671 try:
@@ -686,7 +702,7 @@ def main():
686 # Predict the noise residual 702 # Predict the noise residual
687 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 703 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
688 704
689 if args.class_identifier is not None: 705 if args.num_class_images != 0:
690 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 706 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
691 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 707 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
692 noise, noise_prior = torch.chunk(noise, 2, dim=0) 708 noise, noise_prior = torch.chunk(noise, 2, dim=0)