Diffstat (limited to 'dreambooth.py')
 dreambooth.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 42d3980..770ad38 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -430,7 +430,7 @@ class Checkpointer:
                         eta=eta,
                         num_inference_steps=num_inference_steps,
                         output_type='pil'
-                    )["sample"]
+                    ).images
 
                     all_samples += samples
 
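Note: switching from indexing the pipeline result with ["sample"] to reading .images follows the diffusers API, where a pipeline call returns an output object (StableDiffusionPipelineOutput) rather than a plain dict. A minimal sketch of the new access pattern; the model id and prompt are illustrative, not from this repo:

    from diffusers import StableDiffusionPipeline

    # any Stable Diffusion checkpoint works here; this id is an example
    pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    result = pipeline("a photo of a dog", num_inference_steps=20, eta=0.0, output_type='pil')
    samples = result.images  # list of PIL.Image, formerly result["sample"]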
@@ -537,6 +537,12 @@ def main():
         num_train_timesteps=args.noise_timesteps
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
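Note: the new weight_dtype block resolves the --mixed_precision flag into a torch dtype once, so all later casts stay consistent. The same mapping as a self-contained helper, for illustration only (the function name is hypothetical, the flag values "fp16"/"bf16" are taken from the diff above):

    import torch

    def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
        # "fp16" and "bf16" match the string values checked in the diff
        if mixed_precision == "fp16":
            return torch.float16
        if mixed_precision == "bf16":
            return torch.bfloat16
        return torch.float32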
@@ -549,7 +555,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]
 
         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
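Note: the batch is now cast to weight_dtype inside collate_fn, so pixel_values already match the precision the vae runs in by the time they are encoded. A runnable sketch of the cast and its effect; shapes and dtype are illustrative:

    import torch

    pixel_values = torch.stack([torch.rand(3, 512, 512) for _ in range(2)])
    pixel_values = pixel_values.to(dtype=torch.float16, memory_format=torch.contiguous_format)
    print(pixel_values.dtype, pixel_values.is_contiguous())  # torch.float16 True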
@@ -651,8 +657,8 @@ def main():
     )
 
     # Move text_encoder and vae to device
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
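Note: Module.to accepts a device and a dtype together, so the frozen text_encoder and vae are moved and downcast in one call, roughly halving their memory footprint under fp16 or bf16 while the trained model keeps its own precision handling. A tiny stand-in sketch (nn.Linear substitutes for the real modules):

    import torch
    import torch.nn as nn

    frozen = nn.Linear(4, 4)                 # stand-in for text_encoder / vae
    frozen.to(torch.device("cpu"), dtype=torch.float16)
    print(next(frozen.parameters()).dtype)   # torch.float16
    frozen.eval()                            # kept in eval mode, as in the diff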
@@ -738,7 +744,7 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
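Note: torch.randn_like(latents) inherits shape, dtype and device from latents in one step, so the noise is sampled directly on the target device in weight_dtype instead of being created in float32 on the CPU and then copied over. Runnable check:

    import torch

    latents = torch.randn(2, 4, 64, 64, dtype=torch.bfloat16)
    noise = torch.randn_like(latents)
    assert noise.shape == latents.shape
    assert noise.dtype == latents.dtype and noise.device == latents.device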
@@ -761,15 +767,15 @@ def main():
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
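Note: the loss terms now cast predictions and targets to float32 before the MSE, the usual guard against fp16/bf16 precision loss in the reduction. prior_loss also drops the per-image mean([1, 2, 3]).mean() in favor of a plain reduction="mean"; for equally shaped images the two reductions are numerically identical, as this sketch checks:

    import torch
    import torch.nn.functional as F

    pred, target = torch.rand(2, 4, 8, 8), torch.rand(2, 4, 8, 8)
    flat = F.mse_loss(pred, target, reduction="mean")
    per_image = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3]).mean()
    assert torch.allclose(flat, per_image)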
@@ -818,7 +824,7 @@ def main():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                           (bsz,), device=latents.device)
@@ -832,7 +838,7 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 loss = loss.detach().item()
                 val_loss += loss