Diffstat (limited to 'dreambooth.py')
 dreambooth.py | 53 +++++++++++++++++++++------------------------------
 1 file changed, 23 insertions(+), 30 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 02f83c6..775aea2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -167,7 +167,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1.0
+        default=7 / 8
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -468,20 +468,20 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
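Note: switching from path concatenation to the subfolder keyword is more than cosmetic. `args.pretrained_model_name_or_path + '/tokenizer'` only resolves for a local directory, while `subfolder='tokenizer'` also works when the argument is a Hugging Face Hub repo id. A short sketch (the repo id is illustrative):

    from transformers import CLIPTokenizer

    # Works for a local pipeline directory *or* a Hub repo id.
    tokenizer = CLIPTokenizer.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        subfolder="tokenizer",
    )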
@@ -538,7 +538,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]
 
         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)
 
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
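Note: dropping the explicit dtype keeps the pixel tensors in whatever precision the image transforms produced; `.to(memory_format=torch.contiguous_format)` only normalizes the memory layout. A tiny sketch of that distinction:

    import torch

    x = torch.rand(4, 3, 512, 512)                   # float32 from the transforms
    y = x.to(memory_format=torch.contiguous_format)  # layout only, dtype unchanged
    assert y.dtype == x.dtype and y.is_contiguous()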
@@ -629,16 +629,10 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
-    ema_unet = EMAModel(
-        unet,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
+    ema_unet.averaged_model.to(accelerator.device)
 
     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
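Note: the EMAModel construction moves up to the model-loading hunk above, so the shadow copy is built from the raw UNet before accelerator.prepare() wraps it, and only that copy has to be moved to the device here. As written, `ema_unet.averaged_model.to(...)` is unconditional even though ema_unet is None without --use_ema; a guard along these lines would be needed in that case (a sketch, using the EMAModel API of that diffusers era):

    if ema_unet is not None:
        ema_unet.averaged_model.to(accelerator.device)

    # ... and in the training loop, after the optimizer step:
    if ema_unet is not None:
        ema_unet.step(unet)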
@@ -698,7 +692,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")
 
     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -709,7 +703,7 @@ def main():
 
     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
             unet.train()
@@ -720,9 +714,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
 
                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -737,8 +730,7 @@ def main():
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                     # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
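Note: both torch.no_grad() blocks around the frozen VAE and text encoder are removed in this commit. That is only equivalent if those modules' parameters have gradients disabled; eval() alone does not stop autograd from recording their forward passes. A hedged sketch of the freezing that makes the removal safe:

    # Frozen models: no parameter updates, so no graph needs to be built
    # through them once their parameters stop requiring grad.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)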
@@ -840,7 +832,8 @@ def main():
             global_progress_bar.clear()
 
             if min_val_loss > val_loss:
-                accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                accelerator.print(
+                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 min_val_loss = val_loss
 
             if sample_checkpoint and accelerator.is_main_process: