diff options
author | Volpeon <git@volpeon.ink> | 2022-10-03 21:28:52 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-03 21:28:52 +0200 |
commit | 46b6c09a18b41edff77c6881529b66733d788abe (patch) | |
tree | 670e7cdda37ba7a010b570398a63dd38e357b6ce /textual_inversion.py | |
parent | Small perf improvements (diff) | |
download | textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.tar.gz textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.tar.bz2 textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.zip |
Dreambooth: Generate specialized class images from input prompts
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 5fc2338..4c4da29 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -694,8 +694,7 @@ def main(): | |||
694 | # Predict the noise residual | 694 | # Predict the noise residual |
695 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 695 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
696 | 696 | ||
697 | with accelerator.autocast(): | 697 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
698 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
699 | 698 | ||
700 | accelerator.backward(loss) | 699 | accelerator.backward(loss) |
701 | 700 | ||
@@ -766,8 +765,7 @@ def main(): | |||
766 | 765 | ||
767 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 766 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
768 | 767 | ||
769 | with accelerator.autocast(): | 768 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
770 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
771 | 769 | ||
772 | loss = loss.detach().item() | 770 | loss = loss.detach().item() |
773 | val_loss += loss | 771 | val_loss += loss |