diff options
author | Volpeon <git@volpeon.ink> | 2022-12-27 13:58:48 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-27 13:58:48 +0100 |
commit | 6df1fc46daca9c289f1d7f7524e01deac5c92fd1 (patch) | |
tree | 2ebac26cb0fd377a95437ee54b517011fed36eac /train_dreambooth.py | |
parent | Added validation phase to learn rate finder (diff) | |
download | textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.gz textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.bz2 textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.zip |
Improved learning rate finder
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index a62cec9..325fe90 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -970,9 +970,8 @@ def main(): | |||
970 | avg_loss_val.update(loss.detach_(), bsz) | 970 | avg_loss_val.update(loss.detach_(), bsz) |
971 | avg_acc_val.update(acc.detach_(), bsz) | 971 | avg_acc_val.update(acc.detach_(), bsz) |
972 | 972 | ||
973 | if accelerator.sync_gradients: | 973 | local_progress_bar.update(1) |
974 | local_progress_bar.update(1) | 974 | global_progress_bar.update(1) |
975 | global_progress_bar.update(1) | ||
976 | 975 | ||
977 | logs = { | 976 | logs = { |
978 | "val/loss": avg_loss_val.avg.item(), | 977 | "val/loss": avg_loss_val.avg.item(), |