From 4ecc62e4dd854dce683dffa040677f609bf3a33d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Sep 2022 18:25:24 +0200 Subject: Incorporate upstream changes --- dreambooth.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 8 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 45a0497..c01cbe3 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -135,6 +135,11 @@ def parse_args(): default=0, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument( "--adam_beta1", type=float, @@ -237,12 +242,24 @@ def parse_args(): default=None, help="The prompt to specify images in the same class as provided intance images.", ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss." + ) parser.add_argument( "--with_prior_preservation", default=False, action="store_true", help="Flag to add prior perservation loss.", ) + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm." + ) parser.add_argument( "--num_class_images", type=int, @@ -550,8 +567,19 @@ def main(): args.train_batch_size * accelerator.num_processes ) + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + # Initialize the optimizer - optimizer = torch.optim.AdamW( + optimizer = optimizer_class( unet.parameters(), # only optimize unet lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), @@ -705,8 +733,9 @@ def main(): noise = torch.randn(latents.shape).to(latents.device) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, - (bsz,), device=latents.device).long() + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -719,14 +748,30 @@ def main(): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + if args.with_prior_preservation: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + noise, noise_prior = torch.chunk(noise, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(noise_pred_prior, noise_prior, + reduction="none").mean([1, 2, 3]).mean() + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) optimizer.step() if not accelerator.optimizer_step_was_skipped: lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + optimizer.zero_grad() loss = loss.detach().item() train_loss += loss @@ -765,8 +810,9 @@ def main(): noise = torch.randn(latents.shape).to(latents.device) bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, - (bsz,), device=latents.device).long() + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -776,7 +822,18 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + if args.with_prior_preservation: + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + noise, noise_prior = torch.chunk(noise, 2, dim=0) + + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + + prior_loss = F.mse_loss(noise_pred_prior, noise_prior, + reduction="none").mean([1, 2, 3]).mean() + + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = loss.detach().item() val_loss += loss -- cgit v1.2.3-54-g00ecf