diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 26 |
1 file changed, 16 insertions, 10 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 42d3980..770ad38 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -430,7 +430,7 @@ class Checkpointer: | |||
430 | eta=eta, | 430 | eta=eta, |
431 | num_inference_steps=num_inference_steps, | 431 | num_inference_steps=num_inference_steps, |
432 | output_type='pil' | 432 | output_type='pil' |
433 | )["sample"] | 433 | ).images |
434 | 434 | ||
435 | all_samples += samples | 435 | all_samples += samples |
436 | 436 | ||
@@ -537,6 +537,12 @@ def main(): | |||
537 | num_train_timesteps=args.noise_timesteps | 537 | num_train_timesteps=args.noise_timesteps |
538 | ) | 538 | ) |
539 | 539 | ||
540 | weight_dtype = torch.float32 | ||
541 | if args.mixed_precision == "fp16": | ||
542 | weight_dtype = torch.float16 | ||
543 | elif args.mixed_precision == "bf16": | ||
544 | weight_dtype = torch.bfloat16 | ||
545 | |||
540 | def collate_fn(examples): | 546 | def collate_fn(examples): |
541 | prompts = [example["prompts"] for example in examples] | 547 | prompts = [example["prompts"] for example in examples] |
542 | nprompts = [example["nprompts"] for example in examples] | 548 | nprompts = [example["nprompts"] for example in examples] |
@@ -549,7 +555,7 @@ def main(): | |||
549 | pixel_values += [example["class_images"] for example in examples] | 555 | pixel_values += [example["class_images"] for example in examples] |
550 | 556 | ||
551 | pixel_values = torch.stack(pixel_values) | 557 | pixel_values = torch.stack(pixel_values) |
552 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format) | 558 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
553 | 559 | ||
554 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 560 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids |
555 | 561 | ||
@@ -651,8 +657,8 @@ def main(): | |||
651 | ) | 657 | ) |
652 | 658 | ||
653 | # Move text_encoder and vae to device | 659 | # Move text_encoder and vae to device |
654 | text_encoder.to(accelerator.device) | 660 | text_encoder.to(accelerator.device, dtype=weight_dtype) |
655 | vae.to(accelerator.device) | 661 | vae.to(accelerator.device, dtype=weight_dtype) |
656 | 662 | ||
657 | # Keep text_encoder and vae in eval mode as we don't train these | 663 | # Keep text_encoder and vae in eval mode as we don't train these |
658 | text_encoder.eval() | 664 | text_encoder.eval() |
@@ -738,7 +744,7 @@ def main(): | |||
738 | latents = latents * 0.18215 | 744 | latents = latents * 0.18215 |
739 | 745 | ||
740 | # Sample noise that we'll add to the latents | 746 | # Sample noise that we'll add to the latents |
741 | noise = torch.randn(latents.shape).to(latents.device) | 747 | noise = torch.randn_like(latents) |
742 | bsz = latents.shape[0] | 748 | bsz = latents.shape[0] |
743 | # Sample a random timestep for each image | 749 | # Sample a random timestep for each image |
744 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | 750 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
@@ -761,15 +767,15 @@ def main(): | |||
761 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 767 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
762 | 768 | ||
763 | # Compute instance loss | 769 | # Compute instance loss |
764 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 770 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() |
765 | 771 | ||
766 | # Compute prior loss | 772 | # Compute prior loss |
767 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() | 773 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") |
768 | 774 | ||
769 | # Add the prior loss to the instance loss. | 775 | # Add the prior loss to the instance loss. |
770 | loss = loss + args.prior_loss_weight * prior_loss | 776 | loss = loss + args.prior_loss_weight * prior_loss |
771 | else: | 777 | else: |
772 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 778 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") |
773 | 779 | ||
774 | accelerator.backward(loss) | 780 | accelerator.backward(loss) |
775 | if accelerator.sync_gradients: | 781 | if accelerator.sync_gradients: |
@@ -818,7 +824,7 @@ def main(): | |||
818 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 824 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
819 | latents = latents * 0.18215 | 825 | latents = latents * 0.18215 |
820 | 826 | ||
821 | noise = torch.randn(latents.shape).to(latents.device) | 827 | noise = torch.randn_like(latents) |
822 | bsz = latents.shape[0] | 828 | bsz = latents.shape[0] |
823 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | 829 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
824 | (bsz,), device=latents.device) | 830 | (bsz,), device=latents.device) |
@@ -832,7 +838,7 @@ def main(): | |||
832 | 838 | ||
833 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 839 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
834 | 840 | ||
835 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 841 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") |
836 | 842 | ||
837 | loss = loss.detach().item() | 843 | loss = loss.detach().item() |
838 | val_loss += loss | 844 | val_loss += loss |