diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 53 |
1 files changed, 23 insertions, 30 deletions
diff --git a/dreambooth.py b/dreambooth.py index 02f83c6..775aea2 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -112,7 +112,7 @@ def parse_args(): | |||
112 | parser.add_argument( | 112 | parser.add_argument( |
113 | "--max_train_steps", | 113 | "--max_train_steps", |
114 | type=int, | 114 | type=int, |
115 | default=5000, | 115 | default=3000, |
116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
117 | ) | 117 | ) |
118 | parser.add_argument( | 118 | parser.add_argument( |
@@ -150,7 +150,7 @@ def parse_args(): | |||
150 | parser.add_argument( | 150 | parser.add_argument( |
151 | "--lr_warmup_steps", | 151 | "--lr_warmup_steps", |
152 | type=int, | 152 | type=int, |
153 | default=600, | 153 | default=500, |
154 | help="Number of steps for the warmup in the lr scheduler." | 154 | help="Number of steps for the warmup in the lr scheduler." |
155 | ) | 155 | ) |
156 | parser.add_argument( | 156 | parser.add_argument( |
@@ -167,7 +167,7 @@ def parse_args(): | |||
167 | parser.add_argument( | 167 | parser.add_argument( |
168 | "--ema_power", | 168 | "--ema_power", |
169 | type=float, | 169 | type=float, |
170 | default=1.0 | 170 | default=7 / 8 |
171 | ) | 171 | ) |
172 | parser.add_argument( | 172 | parser.add_argument( |
173 | "--ema_max_decay", | 173 | "--ema_max_decay", |
@@ -468,20 +468,20 @@ def main(): | |||
468 | if args.tokenizer_name: | 468 | if args.tokenizer_name: |
469 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) | 469 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) |
470 | elif args.pretrained_model_name_or_path: | 470 | elif args.pretrained_model_name_or_path: |
471 | tokenizer = CLIPTokenizer.from_pretrained( | 471 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') |
472 | args.pretrained_model_name_or_path + '/tokenizer' | ||
473 | ) | ||
474 | 472 | ||
475 | # Load models and create wrapper for stable diffusion | 473 | # Load models and create wrapper for stable diffusion |
476 | text_encoder = CLIPTextModel.from_pretrained( | 474 | text_encoder = CLIPTextModel.from_pretrained( |
477 | args.pretrained_model_name_or_path + '/text_encoder', | 475 | args.pretrained_model_name_or_path, subfolder='text_encoder') |
478 | ) | 476 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') |
479 | vae = AutoencoderKL.from_pretrained( | 477 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') |
480 | args.pretrained_model_name_or_path + '/vae', | 478 | |
481 | ) | 479 | ema_unet = EMAModel( |
482 | unet = UNet2DConditionModel.from_pretrained( | 480 | unet, |
483 | args.pretrained_model_name_or_path + '/unet', | 481 | inv_gamma=args.ema_inv_gamma, |
484 | ) | 482 | power=args.ema_power, |
483 | max_value=args.ema_max_decay | ||
484 | ) if args.use_ema else None | ||
485 | 485 | ||
486 | if args.gradient_checkpointing: | 486 | if args.gradient_checkpointing: |
487 | unet.enable_gradient_checkpointing() | 487 | unet.enable_gradient_checkpointing() |
@@ -538,7 +538,7 @@ def main(): | |||
538 | pixel_values += [example["class_images"] for example in examples] | 538 | pixel_values += [example["class_images"] for example in examples] |
539 | 539 | ||
540 | pixel_values = torch.stack(pixel_values) | 540 | pixel_values = torch.stack(pixel_values) |
541 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) | 541 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format) |
542 | 542 | ||
543 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 543 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids |
544 | 544 | ||
@@ -629,16 +629,10 @@ def main(): | |||
629 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 629 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
630 | ) | 630 | ) |
631 | 631 | ||
632 | ema_unet = EMAModel( | ||
633 | unet, | ||
634 | inv_gamma=args.ema_inv_gamma, | ||
635 | power=args.ema_power, | ||
636 | max_value=args.ema_max_decay | ||
637 | ) if args.use_ema else None | ||
638 | |||
639 | # Move text_encoder and vae to device | 632 | # Move text_encoder and vae to device |
640 | text_encoder.to(accelerator.device) | 633 | text_encoder.to(accelerator.device) |
641 | vae.to(accelerator.device) | 634 | vae.to(accelerator.device) |
635 | ema_unet.averaged_model.to(accelerator.device) | ||
642 | 636 | ||
643 | # Keep text_encoder and vae in eval mode as we don't train these | 637 | # Keep text_encoder and vae in eval mode as we don't train these |
644 | text_encoder.eval() | 638 | text_encoder.eval() |
@@ -698,7 +692,7 @@ def main(): | |||
698 | disable=not accelerator.is_local_main_process, | 692 | disable=not accelerator.is_local_main_process, |
699 | dynamic_ncols=True | 693 | dynamic_ncols=True |
700 | ) | 694 | ) |
701 | local_progress_bar.set_description("Batch X out of Y") | 695 | local_progress_bar.set_description("Epoch X / Y") |
702 | 696 | ||
703 | global_progress_bar = tqdm( | 697 | global_progress_bar = tqdm( |
704 | range(args.max_train_steps + val_steps), | 698 | range(args.max_train_steps + val_steps), |
@@ -709,7 +703,7 @@ def main(): | |||
709 | 703 | ||
710 | try: | 704 | try: |
711 | for epoch in range(num_epochs): | 705 | for epoch in range(num_epochs): |
712 | local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}") | 706 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
713 | local_progress_bar.reset() | 707 | local_progress_bar.reset() |
714 | 708 | ||
715 | unet.train() | 709 | unet.train() |
@@ -720,9 +714,8 @@ def main(): | |||
720 | for step, batch in enumerate(train_dataloader): | 714 | for step, batch in enumerate(train_dataloader): |
721 | with accelerator.accumulate(unet): | 715 | with accelerator.accumulate(unet): |
722 | # Convert images to latent space | 716 | # Convert images to latent space |
723 | with torch.no_grad(): | 717 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
724 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 718 | latents = latents * 0.18215 |
725 | latents = latents * 0.18215 | ||
726 | 719 | ||
727 | # Sample noise that we'll add to the latents | 720 | # Sample noise that we'll add to the latents |
728 | noise = torch.randn(latents.shape).to(latents.device) | 721 | noise = torch.randn(latents.shape).to(latents.device) |
@@ -737,8 +730,7 @@ def main(): | |||
737 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 730 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
738 | 731 | ||
739 | # Get the text embedding for conditioning | 732 | # Get the text embedding for conditioning |
740 | with torch.no_grad(): | 733 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] |
741 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | ||
742 | 734 | ||
743 | # Predict the noise residual | 735 | # Predict the noise residual |
744 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 736 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
@@ -840,7 +832,8 @@ def main(): | |||
840 | global_progress_bar.clear() | 832 | global_progress_bar.clear() |
841 | 833 | ||
842 | if min_val_loss > val_loss: | 834 | if min_val_loss > val_loss: |
843 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 835 | accelerator.print( |
836 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | ||
844 | min_val_loss = val_loss | 837 | min_val_loss = val_loss |
845 | 838 | ||
846 | if sample_checkpoint and accelerator.is_main_process: | 839 | if sample_checkpoint and accelerator.is_main_process: |