summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-30 13:48:26 +0100
committerVolpeon <git@volpeon.ink>2022-12-30 13:48:26 +0100
commitdfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch)
treeda07cbadfad6f54e55e43e2fda21cef80cded5ea /train_lora.py
parentUpdate (diff)
downloadtextual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz
textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2
textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip
Training script improvements
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py
index 9a42cae..de878a4 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -810,7 +810,7 @@ def main():
810 target, target_prior = torch.chunk(target, 2, dim=0) 810 target, target_prior = torch.chunk(target, 2, dim=0)
811 811
812 # Compute instance loss 812 # Compute instance loss
813 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 813 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
814 814
815 # Compute prior loss 815 # Compute prior loss
816 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 816 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")