From 633d890e4964e070be9b0a5b299c2f2e51d4b055 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 17 Oct 2022 12:27:53 +0200 Subject: Upstream updates; better handling of textual embedding --- dreambooth.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 42d3980..770ad38 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -430,7 +430,7 @@ class Checkpointer: eta=eta, num_inference_steps=num_inference_steps, output_type='pil' - )["sample"] + ).images all_samples += samples @@ -537,6 +537,12 @@ def main(): num_train_timesteps=args.noise_timesteps ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + def collate_fn(examples): prompts = [example["prompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] @@ -549,7 +555,7 @@ def main(): pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format) + pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids @@ -651,8 +657,8 @@ def main(): ) # Move text_encoder and vae to device - text_encoder.to(accelerator.device) - vae.to(accelerator.device) + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) # Keep text_encoder and vae in eval mode as we don't train these text_encoder.eval() @@ -738,7 +744,7 @@ def main(): latents = latents * 0.18215 # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) + noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, @@ -761,15 +767,15 @@ def main(): noise, noise_prior = torch.chunk(noise, 2, dim=0) # Compute instance loss - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() + prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: @@ -818,7 +824,7 @@ def main(): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 - noise = torch.randn(latents.shape).to(latents.device) + noise = torch.randn_like(latents) bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -832,7 +838,7 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") loss = loss.detach().item() val_loss += loss -- cgit v1.2.3-54-g00ecf