summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 5fc2338..4c4da29 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -694,8 +694,7 @@ def main():
694 # Predict the noise residual 694 # Predict the noise residual
695 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 695 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
696 696
697 with accelerator.autocast(): 697 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
698 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
699 698
700 accelerator.backward(loss) 699 accelerator.backward(loss)
701 700
@@ -766,8 +765,7 @@ def main():
766 765
767 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 766 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
768 767
769 with accelerator.autocast(): 768 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
770 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
771 769
772 loss = loss.detach().item() 770 loss = loss.detach().item()
773 val_loss += loss 771 val_loss += loss