summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py73
-rw-r--r--environment.yaml1
2 files changed, 66 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 45a0497..c01cbe3 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -136,6 +136,11 @@ def parse_args():
136 help="Number of steps for the warmup in the lr scheduler." 136 help="Number of steps for the warmup in the lr scheduler."
137 ) 137 )
138 parser.add_argument( 138 parser.add_argument(
139 "--use_8bit_adam",
140 action="store_true",
141 help="Whether or not to use 8-bit Adam from bitsandbytes."
142 )
143 parser.add_argument(
139 "--adam_beta1", 144 "--adam_beta1",
140 type=float, 145 type=float,
141 default=0.9, 146 default=0.9,
@@ -238,12 +243,24 @@ def parse_args():
238 help="The prompt to specify images in the same class as provided intance images.", 243 help="The prompt to specify images in the same class as provided intance images.",
239 ) 244 )
240 parser.add_argument( 245 parser.add_argument(
246 "--prior_loss_weight",
247 type=float,
248 default=1.0,
249 help="The weight of prior preservation loss."
250 )
251 parser.add_argument(
241 "--with_prior_preservation", 252 "--with_prior_preservation",
242 default=False, 253 default=False,
243 action="store_true", 254 action="store_true",
244 help="Flag to add prior perservation loss.", 255 help="Flag to add prior perservation loss.",
245 ) 256 )
246 parser.add_argument( 257 parser.add_argument(
258 "--max_grad_norm",
259 default=1.0,
260 type=float,
261 help="Max gradient norm."
262 )
263 parser.add_argument(
247 "--num_class_images", 264 "--num_class_images",
248 type=int, 265 type=int,
249 default=100, 266 default=100,
@@ -550,8 +567,19 @@ def main():
550 args.train_batch_size * accelerator.num_processes 567 args.train_batch_size * accelerator.num_processes
551 ) 568 )
552 569
570 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
571 if args.use_8bit_adam:
572 try:
573 import bitsandbytes as bnb
574 except ImportError:
575 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
576
577 optimizer_class = bnb.optim.AdamW8bit
578 else:
579 optimizer_class = torch.optim.AdamW
580
553 # Initialize the optimizer 581 # Initialize the optimizer
554 optimizer = torch.optim.AdamW( 582 optimizer = optimizer_class(
555 unet.parameters(), # only optimize unet 583 unet.parameters(), # only optimize unet
556 lr=args.learning_rate, 584 lr=args.learning_rate,
557 betas=(args.adam_beta1, args.adam_beta2), 585 betas=(args.adam_beta1, args.adam_beta2),
@@ -705,8 +733,9 @@ def main():
705 noise = torch.randn(latents.shape).to(latents.device) 733 noise = torch.randn(latents.shape).to(latents.device)
706 bsz = latents.shape[0] 734 bsz = latents.shape[0]
707 # Sample a random timestep for each image 735 # Sample a random timestep for each image
708 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, 736 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
709 (bsz,), device=latents.device).long() 737 (bsz,), device=latents.device)
738 timesteps = timesteps.long()
710 739
711 # Add noise to the latents according to the noise magnitude at each timestep 740 # Add noise to the latents according to the noise magnitude at each timestep
712 # (this is the forward diffusion process) 741 # (this is the forward diffusion process)
@@ -719,14 +748,30 @@ def main():
719 # Predict the noise residual 748 # Predict the noise residual
720 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 749 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
721 750
722 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 751 if args.with_prior_preservation:
752 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
753 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
754 noise, noise_prior = torch.chunk(noise, 2, dim=0)
755
756 # Compute instance loss
757 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
758
759 # Compute prior loss
760 prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
761 reduction="none").mean([1, 2, 3]).mean()
762
763 # Add the prior loss to the instance loss.
764 loss = loss + args.prior_loss_weight * prior_loss
765 else:
766 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
723 767
724 accelerator.backward(loss) 768 accelerator.backward(loss)
769 accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
725 770
726 optimizer.step() 771 optimizer.step()
727 if not accelerator.optimizer_step_was_skipped: 772 if not accelerator.optimizer_step_was_skipped:
728 lr_scheduler.step() 773 lr_scheduler.step()
729 optimizer.zero_grad(set_to_none=True) 774 optimizer.zero_grad()
730 775
731 loss = loss.detach().item() 776 loss = loss.detach().item()
732 train_loss += loss 777 train_loss += loss
@@ -765,8 +810,9 @@ def main():
765 810
766 noise = torch.randn(latents.shape).to(latents.device) 811 noise = torch.randn(latents.shape).to(latents.device)
767 bsz = latents.shape[0] 812 bsz = latents.shape[0]
768 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, 813 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
769 (bsz,), device=latents.device).long() 814 (bsz,), device=latents.device)
815 timesteps = timesteps.long()
770 816
771 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 817 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
772 818
@@ -776,7 +822,18 @@ def main():
776 822
777 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 823 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
778 824
779 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 825 if args.with_prior_preservation:
826 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
827 noise, noise_prior = torch.chunk(noise, 2, dim=0)
828
829 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
830
831 prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
832 reduction="none").mean([1, 2, 3]).mean()
833
834 loss = loss + args.prior_loss_weight * prior_loss
835 else:
836 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
780 837
781 loss = loss.detach().item() 838 loss = loss.detach().item()
782 val_loss += loss 839 val_loss += loss
diff --git a/environment.yaml b/environment.yaml
index 46a4388..c9f498e 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -19,6 +19,7 @@ dependencies:
19 - -e git+https://github.com/ShivamShrirao/diffusers#egg=diffusers 19 - -e git+https://github.com/ShivamShrirao/diffusers#egg=diffusers
20 - accelerate==0.12.0 20 - accelerate==0.12.0
21 - albumentations==1.1.0 21 - albumentations==1.1.0
22 - bitsandbytes==0.34.0
22 - einops==0.4.1 23 - einops==0.4.1
23 - imageio==2.22.0 24 - imageio==2.22.0
24 - kornia==0.6 25 - kornia==0.6