diff options
| -rw-r--r-- | dreambooth.py | 73 | ||||
| -rw-r--r-- | environment.yaml | 1 |
2 files changed, 66 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py index 45a0497..c01cbe3 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -136,6 +136,11 @@ def parse_args(): | |||
| 136 | help="Number of steps for the warmup in the lr scheduler." | 136 | help="Number of steps for the warmup in the lr scheduler." |
| 137 | ) | 137 | ) |
| 138 | parser.add_argument( | 138 | parser.add_argument( |
| 139 | "--use_8bit_adam", | ||
| 140 | action="store_true", | ||
| 141 | help="Whether or not to use 8-bit Adam from bitsandbytes." | ||
| 142 | ) | ||
| 143 | parser.add_argument( | ||
| 139 | "--adam_beta1", | 144 | "--adam_beta1", |
| 140 | type=float, | 145 | type=float, |
| 141 | default=0.9, | 146 | default=0.9, |
| @@ -238,12 +243,24 @@ def parse_args(): | |||
| 238 | help="The prompt to specify images in the same class as provided intance images.", | 243 | help="The prompt to specify images in the same class as provided intance images.", |
| 239 | ) | 244 | ) |
| 240 | parser.add_argument( | 245 | parser.add_argument( |
| 246 | "--prior_loss_weight", | ||
| 247 | type=float, | ||
| 248 | default=1.0, | ||
| 249 | help="The weight of prior preservation loss." | ||
| 250 | ) | ||
| 251 | parser.add_argument( | ||
| 241 | "--with_prior_preservation", | 252 | "--with_prior_preservation", |
| 242 | default=False, | 253 | default=False, |
| 243 | action="store_true", | 254 | action="store_true", |
| 244 | help="Flag to add prior perservation loss.", | 255 | help="Flag to add prior perservation loss.", |
| 245 | ) | 256 | ) |
| 246 | parser.add_argument( | 257 | parser.add_argument( |
| 258 | "--max_grad_norm", | ||
| 259 | default=1.0, | ||
| 260 | type=float, | ||
| 261 | help="Max gradient norm." | ||
| 262 | ) | ||
| 263 | parser.add_argument( | ||
| 247 | "--num_class_images", | 264 | "--num_class_images", |
| 248 | type=int, | 265 | type=int, |
| 249 | default=100, | 266 | default=100, |
| @@ -550,8 +567,19 @@ def main(): | |||
| 550 | args.train_batch_size * accelerator.num_processes | 567 | args.train_batch_size * accelerator.num_processes |
| 551 | ) | 568 | ) |
| 552 | 569 | ||
| 570 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | ||
| 571 | if args.use_8bit_adam: | ||
| 572 | try: | ||
| 573 | import bitsandbytes as bnb | ||
| 574 | except ImportError: | ||
| 575 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | ||
| 576 | |||
| 577 | optimizer_class = bnb.optim.AdamW8bit | ||
| 578 | else: | ||
| 579 | optimizer_class = torch.optim.AdamW | ||
| 580 | |||
| 553 | # Initialize the optimizer | 581 | # Initialize the optimizer |
| 554 | optimizer = torch.optim.AdamW( | 582 | optimizer = optimizer_class( |
| 555 | unet.parameters(), # only optimize unet | 583 | unet.parameters(), # only optimize unet |
| 556 | lr=args.learning_rate, | 584 | lr=args.learning_rate, |
| 557 | betas=(args.adam_beta1, args.adam_beta2), | 585 | betas=(args.adam_beta1, args.adam_beta2), |
| @@ -705,8 +733,9 @@ def main(): | |||
| 705 | noise = torch.randn(latents.shape).to(latents.device) | 733 | noise = torch.randn(latents.shape).to(latents.device) |
| 706 | bsz = latents.shape[0] | 734 | bsz = latents.shape[0] |
| 707 | # Sample a random timestep for each image | 735 | # Sample a random timestep for each image |
| 708 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, | 736 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
| 709 | (bsz,), device=latents.device).long() | 737 | (bsz,), device=latents.device) |
| 738 | timesteps = timesteps.long() | ||
| 710 | 739 | ||
| 711 | # Add noise to the latents according to the noise magnitude at each timestep | 740 | # Add noise to the latents according to the noise magnitude at each timestep |
| 712 | # (this is the forward diffusion process) | 741 | # (this is the forward diffusion process) |
| @@ -719,14 +748,30 @@ def main(): | |||
| 719 | # Predict the noise residual | 748 | # Predict the noise residual |
| 720 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 749 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 721 | 750 | ||
| 722 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 751 | if args.with_prior_preservation: |
| 752 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | ||
| 753 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | ||
| 754 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | ||
| 755 | |||
| 756 | # Compute instance loss | ||
| 757 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 758 | |||
| 759 | # Compute prior loss | ||
| 760 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, | ||
| 761 | reduction="none").mean([1, 2, 3]).mean() | ||
| 762 | |||
| 763 | # Add the prior loss to the instance loss. | ||
| 764 | loss = loss + args.prior_loss_weight * prior_loss | ||
| 765 | else: | ||
| 766 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 723 | 767 | ||
| 724 | accelerator.backward(loss) | 768 | accelerator.backward(loss) |
| 769 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) | ||
| 725 | 770 | ||
| 726 | optimizer.step() | 771 | optimizer.step() |
| 727 | if not accelerator.optimizer_step_was_skipped: | 772 | if not accelerator.optimizer_step_was_skipped: |
| 728 | lr_scheduler.step() | 773 | lr_scheduler.step() |
| 729 | optimizer.zero_grad(set_to_none=True) | 774 | optimizer.zero_grad() |
| 730 | 775 | ||
| 731 | loss = loss.detach().item() | 776 | loss = loss.detach().item() |
| 732 | train_loss += loss | 777 | train_loss += loss |
| @@ -765,8 +810,9 @@ def main(): | |||
| 765 | 810 | ||
| 766 | noise = torch.randn(latents.shape).to(latents.device) | 811 | noise = torch.randn(latents.shape).to(latents.device) |
| 767 | bsz = latents.shape[0] | 812 | bsz = latents.shape[0] |
| 768 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, | 813 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
| 769 | (bsz,), device=latents.device).long() | 814 | (bsz,), device=latents.device) |
| 815 | timesteps = timesteps.long() | ||
| 770 | 816 | ||
| 771 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 817 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 772 | 818 | ||
| @@ -776,7 +822,18 @@ def main(): | |||
| 776 | 822 | ||
| 777 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 823 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 778 | 824 | ||
| 779 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 825 | if args.with_prior_preservation: |
| 826 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | ||
| 827 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | ||
| 828 | |||
| 829 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 830 | |||
| 831 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, | ||
| 832 | reduction="none").mean([1, 2, 3]).mean() | ||
| 833 | |||
| 834 | loss = loss + args.prior_loss_weight * prior_loss | ||
| 835 | else: | ||
| 836 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 780 | 837 | ||
| 781 | loss = loss.detach().item() | 838 | loss = loss.detach().item() |
| 782 | val_loss += loss | 839 | val_loss += loss |
diff --git a/environment.yaml b/environment.yaml index 46a4388..c9f498e 100644 --- a/environment.yaml +++ b/environment.yaml | |||
| @@ -19,6 +19,7 @@ dependencies: | |||
| 19 | - -e git+https://github.com/ShivamShrirao/diffusers#egg=diffusers | 19 | - -e git+https://github.com/ShivamShrirao/diffusers#egg=diffusers |
| 20 | - accelerate==0.12.0 | 20 | - accelerate==0.12.0 |
| 21 | - albumentations==1.1.0 | 21 | - albumentations==1.1.0 |
| 22 | - bitsandbytes==0.34.0 | ||
| 22 | - einops==0.4.1 | 23 | - einops==0.4.1 |
| 23 | - imageio==2.22.0 | 24 | - imageio==2.22.0 |
| 24 | - kornia==0.6 | 25 | - kornia==0.6 |
