commit    633d890e4964e070be9b0a5b299c2f2e51d4b055
tree      235b33195b041e45bb7a6a24471ea55ad4bd7850 /dreambooth.py
parent    Update
author    Volpeon <git@volpeon.ink>  2022-10-17 12:27:53 +0200
committer Volpeon <git@volpeon.ink>  2022-10-17 12:27:53 +0200
Upstream updates; better handling of textual embedding
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py | 26 ++++++++++++++++----------
1 file changed, 16 insertions(+), 10 deletions(-)
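One of the upstream diffusers changes tracked in the first hunk below is that pipeline calls now return an output object whose generated images are exposed as an `.images` attribute rather than under a `"sample"` key. A minimal usage sketch of that newer API (the model id, prompt, and output file name are placeholders, not taken from this repository):

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint; any Stable Diffusion model id works here.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# The call returns a StableDiffusionPipelineOutput; .images is a list of PIL images.
result = pipe("a placeholder prompt", num_inference_steps=20, output_type="pil")
result.images[0].save("sample.png")
```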
diff --git a/dreambooth.py b/dreambooth.py
index 42d3980..770ad38 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -430,7 +430,7 @@ class Checkpointer:
                     eta=eta,
                     num_inference_steps=num_inference_steps,
                     output_type='pil'
-                )["sample"]
+                ).images

                 all_samples += samples

@@ -537,6 +537,12 @@ def main():
         num_train_timesteps=args.noise_timesteps
     )

+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -549,7 +555,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -651,8 +657,8 @@ def main():
     )

     # Move text_encoder and vae to device
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)

     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
@@ -738,7 +744,7 @@ def main():
                 latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -761,15 +767,15 @@ def main():
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)

                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
@@ -818,7 +824,7 @@ def main():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215

-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                           (bsz,), device=latents.device)
@@ -832,7 +838,7 @@ def main():

                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                 loss = loss.detach().item()
                 val_loss += loss
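The remaining hunks share one pattern: batches and the frozen models follow the precision selected via `--mixed_precision`, while the MSE losses are always computed in `float32` via the `.float()` casts. A small self-contained sketch of that pattern, assuming the same flag values as the diff (the helper name and tensor shapes are illustrative, not from the script):

```python
import torch
import torch.nn.functional as F

def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
    # Same mapping the diff adds in main(): fp16/bf16 flags select a reduced-precision dtype.
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

weight_dtype = resolve_weight_dtype("bf16")

# Inputs (and the frozen VAE/text encoder, via .to(device, dtype=weight_dtype))
# live in the reduced precision ...
latents = torch.randn(2, 4, 64, 64, dtype=weight_dtype)
noise = torch.randn_like(latents)      # replaces torch.randn(latents.shape).to(latents.device)
noise_pred = latents + 0.1 * noise     # stand-in for the UNet prediction

# ... but the loss is computed in float32, matching the .float() casts in the diff.
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
print(loss.item())
```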
