summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-02 20:13:59 +0100
committerVolpeon <git@volpeon.ink>2023-01-02 20:13:59 +0100
commit46d631759f59bc6b65458202641e5f5a9bc30b7b (patch)
treeea8c94ff336fe27b6cc8f39cea6c1699f44c61d5 /train_ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-46d631759f59bc6b65458202641e5f5a9bc30b7b.tar.gz
textual-inversion-diff-46d631759f59bc6b65458202641e5f5a9bc30b7b.tar.bz2
textual-inversion-diff-46d631759f59bc6b65458202641e5f5a9bc30b7b.zip
Fixed LR finder
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index 2b3f017..102c0fa 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -584,7 +584,7 @@ def main():
584 ) 584 )
585 585
586 if args.find_lr: 586 if args.find_lr:
587 args.learning_rate = 1e2 587 args.learning_rate = 1e-4
588 588
589 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 589 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
590 if args.use_8bit_adam: 590 if args.use_8bit_adam:
@@ -853,9 +853,9 @@ def main():
853 on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), 853 on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
854 on_eval=lambda: tokenizer.set_use_vector_shuffle(False) 854 on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
855 ) 855 )
856 lr_finder.run(min_lr=1e-4) 856 lr_finder.run(end_lr=1e2)
857 857
858 plt.savefig(basepath.joinpath("lr.png")) 858 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
859 plt.close() 859 plt.close()
860 860
861 quit() 861 quit()