From 0995b6d9b16c3b0ac4971e0d2ef4cf8f3ee050e8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 22 Oct 2022 11:05:12 +0200 Subject: Training update --- dreambooth.py | 161 +++++++++++++++++++++++++++------------------------------- 1 file changed, 75 insertions(+), 86 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 72c56cd..1539e81 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -810,6 +810,75 @@ def main(): ) global_progress_bar.set_description("Total progress") + def run_step(batch, train=False, class_images=False): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if class_images: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + noise, noise_prior = torch.chunk(noise, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + + if train: + accelerator.backward(loss) + + if args.initializer_token is not None: + # Keep the token embeddings fixed except the newly added + # embeddings for the concept, as we only want to optimize the concept embeddings + if accelerator.num_processes > 1: + token_embeds = text_encoder.module.get_input_embeddings().weight + else: + token_embeds = text_encoder.get_input_embeddings().weight + + # Get the index for tokens that we want to freeze + index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id + token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] + + if accelerator.sync_gradients: + params_to_clip = ( + unet.parameters() + if args.initializer_token is not None + else itertools.chain(unet.parameters(), text_encoder.parameters()) + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + if not accelerator.optimizer_step_was_skipped: + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + loss = loss.detach().item() + return loss + try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -823,72 +892,11 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(itertools.chain(unet, text_encoder)): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) - - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.num_class_images != 0: - # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. - noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) - noise, noise_prior = torch.chunk(noise, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() - - # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") - - accelerator.backward(loss) - - if args.initializer_token is not None: - # Keep the token embeddings fixed except the newly added - # embeddings for the concept, as we only want to optimize the concept embeddings - if accelerator.num_processes > 1: - token_embeds = text_encoder.module.get_input_embeddings().weight - else: - token_embeds = text_encoder.get_input_embeddings().weight - - # Get the index for tokens that we want to freeze - index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id - token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] - - if accelerator.sync_gradients: - params_to_clip = ( - unet.parameters() - if args.initializer_token is not None - else itertools.chain(unet.parameters(), text_encoder.parameters()) - ) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - if not accelerator.optimizer_step_was_skipped: - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - loss = loss.detach().item() + loss = run_step( + batch, + train=True, + class_images=args.num_class_images != 0 + ) train_loss += loss # Checks if the accelerator has performed an optimization step behind the scenes @@ -929,26 +937,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - - noise = torch.randn_like(latents) - bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) - - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - - loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") - - loss = loss.detach().item() + loss = run_step(batch) val_loss += loss if accelerator.sync_gradients: -- cgit v1.2.3-54-g00ecf