diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-03 21:28:52 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-03 21:28:52 +0200 |
| commit | 46b6c09a18b41edff77c6881529b66733d788abe (patch) | |
| tree | 670e7cdda37ba7a010b570398a63dd38e357b6ce /textual_inversion.py | |
| parent | Small perf improvements (diff) | |
| download | textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.tar.gz textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.tar.bz2 textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.zip | |
Dreambooth: Generate specialized class images from input prompts
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 5fc2338..4c4da29 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -694,8 +694,7 @@ def main(): | |||
| 694 | # Predict the noise residual | 694 | # Predict the noise residual |
| 695 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 695 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 696 | 696 | ||
| 697 | with accelerator.autocast(): | 697 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 698 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 699 | 698 | ||
| 700 | accelerator.backward(loss) | 699 | accelerator.backward(loss) |
| 701 | 700 | ||
| @@ -766,8 +765,7 @@ def main(): | |||
| 766 | 765 | ||
| 767 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 766 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 768 | 767 | ||
| 769 | with accelerator.autocast(): | 768 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 770 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 771 | 769 | ||
| 772 | loss = loss.detach().item() | 770 | loss = loss.detach().item() |
| 773 | val_loss += loss | 771 | val_loss += loss |
