diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 5fc2338..4c4da29 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -694,8 +694,7 @@ def main(): | |||
694 | # Predict the noise residual | 694 | # Predict the noise residual |
695 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 695 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
696 | 696 | ||
697 | with accelerator.autocast(): | 697 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
698 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
699 | 698 | ||
700 | accelerator.backward(loss) | 699 | accelerator.backward(loss) |
701 | 700 | ||
@@ -766,8 +765,7 @@ def main(): | |||
766 | 765 | ||
767 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 766 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
768 | 767 | ||
769 | with accelerator.autocast(): | 768 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
770 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
771 | 769 | ||
772 | loss = loss.detach().item() | 770 | loss = loss.detach().item() |
773 | val_loss += loss | 771 | val_loss += loss |