summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-27 13:58:48 +0100
committerVolpeon <git@volpeon.ink>2022-12-27 13:58:48 +0100
commit6df1fc46daca9c289f1d7f7524e01deac5c92fd1 (patch)
tree2ebac26cb0fd377a95437ee54b517011fed36eac /train_dreambooth.py
parentAdded validation phase to learn rate finder (diff)
downloadtextual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.gz
textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.bz2
textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.zip
Improved learning rate finder
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a62cec9..325fe90 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -970,9 +970,8 @@ def main():
970 avg_loss_val.update(loss.detach_(), bsz) 970 avg_loss_val.update(loss.detach_(), bsz)
971 avg_acc_val.update(acc.detach_(), bsz) 971 avg_acc_val.update(acc.detach_(), bsz)
972 972
973 if accelerator.sync_gradients: 973 local_progress_bar.update(1)
974 local_progress_bar.update(1) 974 global_progress_bar.update(1)
975 global_progress_bar.update(1)
976 975
977 logs = { 976 logs = {
978 "val/loss": avg_loss_val.avg.item(), 977 "val/loss": avg_loss_val.avg.item(),