summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-02 20:13:59 +0100
committerVolpeon <git@volpeon.ink>2023-01-02 20:13:59 +0100
commit46d631759f59bc6b65458202641e5f5a9bc30b7b (patch)
treeea8c94ff336fe27b6cc8f39cea6c1699f44c61d5 /train_dreambooth.py
parentUpdate (diff)
downloadtextual-inversion-diff-46d631759f59bc6b65458202641e5f5a9bc30b7b.tar.gz
textual-inversion-diff-46d631759f59bc6b65458202641e5f5a9bc30b7b.tar.bz2
textual-inversion-diff-46d631759f59bc6b65458202641e5f5a9bc30b7b.zip
Fixed LR finder
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1e49474..218018b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -634,7 +634,7 @@ def main():
634 ) 634 )
635 635
636 if args.find_lr: 636 if args.find_lr:
637 args.learning_rate = 1e2 637 args.learning_rate = 1e-4
638 638
639 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 639 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
640 if args.use_8bit_adam: 640 if args.use_8bit_adam:
@@ -901,7 +901,7 @@ def main():
901 on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), 901 on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
902 on_eval=lambda: tokenizer.set_use_vector_shuffle(False) 902 on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
903 ) 903 )
904 lr_finder.run(min_lr=1e-4) 904 lr_finder.run(end_lr=1e2)
905 905
906 plt.savefig(basepath.joinpath("lr.png")) 906 plt.savefig(basepath.joinpath("lr.png"))
907 plt.close() 907 plt.close()