diff options
author | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
commit | dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch) | |
tree | da07cbadfad6f54e55e43e2fda21cef80cded5ea /train_lora.py | |
parent | Update (diff) | |
download | textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2 textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip |
Training script improvements
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index 9a42cae..de878a4 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -810,7 +810,7 @@ def main(): | |||
810 | target, target_prior = torch.chunk(target, 2, dim=0) | 810 | target, target_prior = torch.chunk(target, 2, dim=0) |
811 | 811 | ||
812 | # Compute instance loss | 812 | # Compute instance loss |
813 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | 813 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
814 | 814 | ||
815 | # Compute prior loss | 815 | # Compute prior loss |
816 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 816 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |