summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-03 21:28:52 +0200
committerVolpeon <git@volpeon.ink>2022-10-03 21:28:52 +0200
commit46b6c09a18b41edff77c6881529b66733d788abe (patch)
tree670e7cdda37ba7a010b570398a63dd38e357b6ce /textual_inversion.py
parentSmall perf improvements (diff)
downloadtextual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.tar.gz
textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.tar.bz2
textual-inversion-diff-46b6c09a18b41edff77c6881529b66733d788abe.zip
Dreambooth: Generate specialized class images from input prompts
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 5fc2338..4c4da29 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -694,8 +694,7 @@ def main():
694 # Predict the noise residual 694 # Predict the noise residual
695 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 695 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
696 696
697 with accelerator.autocast(): 697 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
698 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
699 698
700 accelerator.backward(loss) 699 accelerator.backward(loss)
701 700
@@ -766,8 +765,7 @@ def main():
766 765
767 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 766 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
768 767
769 with accelerator.autocast(): 768 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
770 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
771 769
772 loss = loss.detach().item() 770 loss = loss.detach().item()
773 val_loss += loss 771 val_loss += loss