diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-27 13:58:48 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-27 13:58:48 +0100 |
| commit | 6df1fc46daca9c289f1d7f7524e01deac5c92fd1 (patch) | |
| tree | 2ebac26cb0fd377a95437ee54b517011fed36eac /train_ti.py | |
| parent | Added validation phase to learn rate finder (diff) | |
| download | textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.gz textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.tar.bz2 textual-inversion-diff-6df1fc46daca9c289f1d7f7524e01deac5c92fd1.zip | |
Improved learning rate finder
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 10 |
1 file changed, 3 insertions, 7 deletions
diff --git a/train_ti.py b/train_ti.py index 32f44f4..870b2ba 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -548,9 +548,6 @@ def main(): | |||
| 548 | args.train_batch_size * accelerator.num_processes | 548 | args.train_batch_size * accelerator.num_processes |
| 549 | ) | 549 | ) |
| 550 | 550 | ||
| 551 | if args.find_lr: | ||
| 552 | args.learning_rate = 1e2 | ||
| 553 | |||
| 554 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 551 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
| 555 | if args.use_8bit_adam: | 552 | if args.use_8bit_adam: |
| 556 | try: | 553 | try: |
| @@ -783,7 +780,7 @@ def main(): | |||
| 783 | 780 | ||
| 784 | if args.find_lr: | 781 | if args.find_lr: |
| 785 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 782 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) |
| 786 | lr_finder.run(num_train_steps=2) | 783 | lr_finder.run(min_lr=1e-6, num_train_batches=4) |
| 787 | 784 | ||
| 788 | plt.savefig(basepath.joinpath("lr.png")) | 785 | plt.savefig(basepath.joinpath("lr.png")) |
| 789 | plt.close() | 786 | plt.close() |
| @@ -908,9 +905,8 @@ def main(): | |||
| 908 | avg_loss_val.update(loss.detach_(), bsz) | 905 | avg_loss_val.update(loss.detach_(), bsz) |
| 909 | avg_acc_val.update(acc.detach_(), bsz) | 906 | avg_acc_val.update(acc.detach_(), bsz) |
| 910 | 907 | ||
| 911 | if accelerator.sync_gradients: | 908 | local_progress_bar.update(1) |
| 912 | local_progress_bar.update(1) | 909 | global_progress_bar.update(1) |
| 913 | global_progress_bar.update(1) | ||
| 914 | 910 | ||
| 915 | logs = { | 911 | logs = { |
| 916 | "val/loss": avg_loss_val.avg.item(), | 912 | "val/loss": avg_loss_val.avg.item(), |
