From 9d6c75262b6919758e781b8333428861a5bf7ede Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Dec 2022 11:02:49 +0100 Subject: Added learning rate finder --- train_dreambooth.py | 129 ++++++++++++++++++++++------------------------------ 1 file changed, 54 insertions(+), 75 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 08bc9e0..a62cec9 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -843,6 +843,58 @@ def main(): ) global_progress_bar.set_description("Total progress") + def loop(batch): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noisy_latents.to(dtype=unet.dtype) + + # Get the text embedding for conditioning + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.num_class_images != 0: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + acc = (model_pred == latents).float().mean() + + return loss, acc, bsz + try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -859,54 +911,7 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - noisy_latents = noisy_latents.to(dtype=unet.dtype) - - # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) - - # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.num_class_images != 0: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() - - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - acc = (model_pred == latents).float().mean() + loss, acc, bsz = loop(batch) accelerator.backward(loss) @@ -960,33 +965,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - - noise = torch.randn_like(latents) - bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) - timesteps = timesteps.long() - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - noisy_latents = noisy_latents.to(dtype=unet.dtype) - - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - acc = (model_pred == latents).float().mean() + loss, acc, bsz = loop(batch) avg_loss_val.update(loss.detach_(), bsz) avg_acc_val.update(acc.detach_(), bsz) -- cgit v1.2.3-54-g00ecf