Diffstat (limited to 'dreambooth.py')
 dreambooth.py | 53 +++++++++++++++++++++++------------------------------
1 file changed, 23 insertions(+), 30 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 02f83c6..775aea2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
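The two hunks above only lower the training defaults (5000 to 3000 steps, 600 to 500 warmup steps). Both remain overridable from the command line; a minimal sketch of how these argparse defaults interact with explicit flags (the flag values passed here are illustrative):

```python
import argparse

# Minimal reproduction of the two arguments touched above; the
# defaults mirror the new values in this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--max_train_steps", type=int, default=3000)
parser.add_argument("--lr_warmup_steps", type=int, default=500)

# No flags given: the new defaults apply.
print(parser.parse_args([]))  # Namespace(max_train_steps=3000, lr_warmup_steps=500)
# Explicit flags still win, so existing launch commands keep working.
print(parser.parse_args(["--max_train_steps", "5000"]))
```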
@@ -167,7 +167,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1.0
+        default=7 / 8
     )
     parser.add_argument(
         "--ema_max_decay",
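`ema_power` feeds the decay warmup of diffusers' `EMAModel`, which (as implemented at the time of this commit) grows the decay as `1 - (1 + step / inv_gamma) ** -power`, clamped to `ema_max_decay`. A sketch of how lowering the power from 1.0 to 7/8 slows the approach to the maximum decay; the `inv_gamma` and `max_decay` values here are assumptions for illustration:

```python
# Sketch of the EMA decay warmup schedule used by diffusers' EMAModel:
# decay = 1 - (1 + step / inv_gamma) ** -power, clamped to max_decay.
def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 7 / 8,
              max_decay: float = 0.9999) -> float:
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max(value, 0.0), max_decay)

for step in (10, 100, 1000, 3000):
    # power=1.0 reaches high decay sooner; 7/8 keeps the running
    # average more responsive early in training.
    print(step, round(ema_decay(step, power=1.0), 5), round(ema_decay(step), 5))
```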
@@ -468,20 +468,20 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
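Switching from string concatenation to the `subfolder` keyword is more than cosmetic: `from_pretrained` resolves `subfolder` for both local directories and Hugging Face Hub model ids, while `path + '/tokenizer'` only ever worked for local paths. A minimal sketch (both the local path and the Hub id are illustrative):

```python
from transformers import CLIPTokenizer

# With subfolder=, the same call works for a local checkout...
tokenizer = CLIPTokenizer.from_pretrained("/models/stable-diffusion", subfolder="tokenizer")
# ...and for a Hub model id, which string concatenation would break:
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
```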
@@ -538,7 +538,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]
 
         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)
 
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
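Dropping `dtype=torch.float32` means the collated batch keeps whatever dtype the image transforms produced instead of being forcibly upcast; only the memory layout is normalized. This matters if upstream transforms ever emit half precision. A sketch of the difference (shapes and dtypes here are illustrative):

```python
import torch

half_images = torch.randn(4, 3, 512, 512, dtype=torch.float16)

# Old behaviour: always upcasts, doubling the batch's memory footprint.
upcast = half_images.to(dtype=torch.float32, memory_format=torch.contiguous_format)
assert upcast.dtype == torch.float32

# New behaviour: dtype is preserved, layout is still made contiguous.
kept = half_images.to(memory_format=torch.contiguous_format)
assert kept.dtype == torch.float16
```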
@@ -629,16 +629,10 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
-    ema_unet = EMAModel(
-        unet,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
+    ema_unet.averaged_model.to(accelerator.device)
 
     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
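This hunk pairs with the earlier one: the `EMAModel` construction moves above `accelerator.prepare(...)`, and the averaged weights are now placed on the training device alongside the frozen models. One caveat: `ema_unet` is `None` when `--use_ema` is off, so the unconditional `.averaged_model.to(...)` assumes EMA is enabled. A defensive variant would guard the move (a suggested hardening, not part of this commit):

```python
# ema_unet is None unless --use_ema was passed; guard before moving it.
if ema_unet is not None:
    ema_unet.averaged_model.to(accelerator.device)
```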
@@ -698,7 +692,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")
 
     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -709,7 +703,7 @@ def main():
 
     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
             unet.train()
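Both progress-bar hunks rename "Batch X out of Y" to "Epoch X / Y", matching what the counter actually iterates over. For reference, a minimal sketch of the relabel-and-reset pattern the script applies each epoch (the bar totals are illustrative):

```python
from tqdm.auto import tqdm

local_progress_bar = tqdm(range(100), dynamic_ncols=True)
for epoch in range(3):
    local_progress_bar.set_description(f"Epoch {epoch + 1} / 3")
    local_progress_bar.reset()  # rewinds the bar without recreating it; the label persists
    for _ in range(100):
        local_progress_bar.update(1)
```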
@@ -720,9 +714,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
 
                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -737,8 +730,7 @@ def main():
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                     # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
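This hunk and the previous one drop the explicit `torch.no_grad()` wrappers around the frozen VAE and text encoder. That is only equivalent if those modules' parameters have `requires_grad=False` (which the script is assumed to do elsewhere): autograd then skips graph construction on its own, since neither the inputs nor the weights require grad. A sketch of that assumption using a stand-in module:

```python
import torch
from torch import nn

# Stand-in for a frozen encoder; the real script would freeze
# vae and text_encoder the same way.
frozen = nn.Linear(8, 8)
frozen.requires_grad_(False)
frozen.eval()

x = torch.randn(2, 8)  # dataloader tensors do not require grad
out = frozen(x)

# With every input and parameter grad-free, autograd builds no graph,
# so the explicit torch.no_grad() wrapper was redundant.
assert not out.requires_grad and out.grad_fn is None
```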
@@ -840,7 +832,8 @@ def main():
                 global_progress_bar.clear()
 
                 if min_val_loss > val_loss:
-                    accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                     min_val_loss = val_loss
 
                 if sample_checkpoint and accelerator.is_main_process:
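The reworded message now includes the global step, which lets a new validation-loss minimum be matched to a specific point in training when reading logs. A minimal sketch of the tracking pattern (names mirror the script; the loss values are illustrative):

```python
min_val_loss = float("inf")

for global_step, val_loss in enumerate([0.31, 0.27, 0.29, 0.22], start=1):
    if min_val_loss > val_loss:
        print(f"Global step {global_step}: Validation loss reached new minimum: "
              f"{min_val_loss:.2e} -> {val_loss:.2e}")
        min_val_loss = val_loss
```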
