summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-27 13:58:48 +0100
committerVolpeon <git@volpeon.ink>2022-12-27 13:58:48 +0100
commit6df1fc46daca9c289f1d7f7524e01deac5c92fd1 (patch)
tree2ebac26cb0fd377a95437ee54b517011fed36eac /train_ti.py
parentAdded validation phase to learn rate finder (diff)
downloadtextual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.gz
textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.bz2
textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.zip
Improved learning rate finder
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py10
1 files changed, 3 insertions, 7 deletions
diff --git a/train_ti.py b/train_ti.py
index 32f44f4..870b2ba 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -548,9 +548,6 @@ def main():
548 args.train_batch_size * accelerator.num_processes 548 args.train_batch_size * accelerator.num_processes
549 ) 549 )
550 550
551 if args.find_lr:
552 args.learning_rate = 1e2
553
554 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 551 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
555 if args.use_8bit_adam: 552 if args.use_8bit_adam:
556 try: 553 try:
@@ -783,7 +780,7 @@ def main():
783 780
784 if args.find_lr: 781 if args.find_lr:
785 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) 782 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
786 lr_finder.run(num_train_steps=2) 783 lr_finder.run(min_lr=1e-6, num_train_batches=4)
787 784
788 plt.savefig(basepath.joinpath("lr.png")) 785 plt.savefig(basepath.joinpath("lr.png"))
789 plt.close() 786 plt.close()
@@ -908,9 +905,8 @@ def main():
908 avg_loss_val.update(loss.detach_(), bsz) 905 avg_loss_val.update(loss.detach_(), bsz)
909 avg_acc_val.update(acc.detach_(), bsz) 906 avg_acc_val.update(acc.detach_(), bsz)
910 907
911 if accelerator.sync_gradients: 908 local_progress_bar.update(1)
912 local_progress_bar.update(1) 909 global_progress_bar.update(1)
913 global_progress_bar.update(1)
914 910
915 logs = { 911 logs = {
916 "val/loss": avg_loss_val.avg.item(), 912 "val/loss": avg_loss_val.avg.item(),