-rw-r--r--  dreambooth.py    | 73
-rw-r--r--  environment.yaml |  1
2 files changed, 66 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 45a0497..c01cbe3 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -136,6 +136,11 @@ def parse_args():
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -238,12 +243,24 @@ def parse_args():
         help="The prompt to specify images in the same class as provided intance images.",
     )
     parser.add_argument(
+        "--prior_loss_weight",
+        type=float,
+        default=1.0,
+        help="The weight of prior preservation loss."
+    )
+    parser.add_argument(
         "--with_prior_preservation",
         default=False,
         action="store_true",
         help="Flag to add prior perservation loss.",
     )
     parser.add_argument(
+        "--max_grad_norm",
+        default=1.0,
+        type=float,
+        help="Max gradient norm."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=100,
@@ -550,8 +567,19 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
     # Initialize the optimizer
-    optimizer = torch.optim.AdamW(
+    optimizer = optimizer_class(
         unet.parameters(),  # only optimize unet
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
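The swap above works because bnb.optim.AdamW8bit accepts the same constructor arguments as torch.optim.AdamW, so only optimizer_class changes and the existing call site is untouched. Below is a minimal standalone sketch of the same pattern, assuming bitsandbytes is installed with CUDA support; the linear layer and hyperparameter values are placeholders (only the 0.9 beta1 default is visible in this diff), and unlike the patch it falls back silently instead of raising when the import fails.

import torch

try:
    import bitsandbytes as bnb  # 8-bit optimizer states cut optimizer memory roughly by three quarters
    optimizer_class = bnb.optim.AdamW8bit
except ImportError:
    optimizer_class = torch.optim.AdamW  # full-precision fallback

model = torch.nn.Linear(4, 4)  # stand-in for the UNet being fine-tuned
optimizer = optimizer_class(
    model.parameters(),
    lr=5e-6,              # placeholder; the script takes this from --learning_rate
    betas=(0.9, 0.999),   # beta1 matches the --adam_beta1 default above; beta2 is a placeholder
    weight_decay=1e-2,    # placeholder
    eps=1e-8,             # placeholder
)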
@@ -705,8 +733,9 @@ def main():
                 noise = torch.randn(latents.shape).to(latents.device)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
-                                          (bsz,), device=latents.device).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                          (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
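Two things change in this hunk: the timestep count is now read from noise_scheduler.config, since newer diffusers releases expose scheduler hyperparameters through the config object, and the int64 cast moves to its own line (torch.randint already defaults to int64, so the explicit .long() mainly guards the dtype). A small standalone sketch of the noising step, assuming a diffusers DDPMScheduler and a placeholder latent shape:

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)  # stand-in for the script's scheduler

latents = torch.randn(2, 4, 64, 64)   # placeholder VAE latents: (batch, channels, height, width)
noise = torch.randn(latents.shape)    # Gaussian noise the UNet will learn to predict
bsz = latents.shape[0]

# One random timestep per sample; hyperparameters are read via .config
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
timesteps = timesteps.long()

# Forward diffusion: blend noise into the latents according to each sampled timestep
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
print(noisy_latents.shape)  # torch.Size([2, 4, 64, 64])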
@@ -719,14 +748,30 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                if args.with_prior_preservation:
+                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                    # Compute instance loss
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                            reduction="none").mean([1, 2, 3]).mean()
+
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
+                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+                optimizer.zero_grad()
 
                 loss = loss.detach().item()
                 train_loss += loss
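The torch.chunk(..., 2, dim=0) split only makes sense if each batch is assembled as instance examples followed by class (prior) examples along the batch dimension; that collate logic is not part of this diff, so the sketch below assumes it and uses dummy tensors in place of the UNet outputs.

import torch
import torch.nn.functional as F

prior_loss_weight = 1.0                  # --prior_loss_weight default from this diff

# Placeholder predictions and targets: first half instance samples, second half class samples
noise_pred = torch.randn(4, 4, 64, 64)
noise = torch.randn(4, 4, 64, 64)

noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)

# Per-sample MSE over channel/height/width, then a batch mean, for each half
instance_loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

loss = instance_loss + prior_loss_weight * prior_loss

The new clipping call uses accelerator.clip_grad_norm_ rather than torch.nn.utils.clip_grad_norm_ so that Accelerate can unscale mixed-precision gradients before enforcing the --max_grad_norm bound.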
@@ -765,8 +810,9 @@ def main():
 
                 noise = torch.randn(latents.shape).to(latents.device)
                 bsz = latents.shape[0]
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
-                                          (bsz,), device=latents.device).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                          (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
@@ -776,7 +822,18 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                if args.with_prior_preservation:
+                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                            reduction="none").mean([1, 2, 3]).mean()
+
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 loss = loss.detach().item()
                 val_loss += loss
diff --git a/environment.yaml b/environment.yaml
index 46a4388..c9f498e 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -19,6 +19,7 @@ dependencies:
     - -e git+https://github.com/ShivamShrirao/diffusers#egg=diffusers
     - accelerate==0.12.0
     - albumentations==1.1.0
+    - bitsandbytes==0.34.0
     - einops==0.4.1
     - imageio==2.22.0
     - kornia==0.6
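The bitsandbytes==0.34.0 pin backs the optional --use_8bit_adam path; the 8-bit optimizer additionally needs a visible CUDA GPU at runtime. A quick sanity check (a sketch, not part of the patch) to run inside the environment before training:

import importlib.metadata
import torch

try:
    import bitsandbytes  # noqa: F401  (imports only if the installed wheel matches the system)
    version = importlib.metadata.version("bitsandbytes")
    print(f"bitsandbytes {version} | CUDA available: {torch.cuda.is_available()}")
except ImportError:
    print("bitsandbytes missing; dreambooth.py will raise if --use_8bit_adam is passed")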