summaryrefslogtreecommitdiffstats
path: root/training/lr.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-29 09:00:19 +0100
committerVolpeon <git@volpeon.ink>2022-12-29 09:00:19 +0100
commit4d3d318a4168ef79847737cef2c0ad8a4dafd3e7 (patch)
tree967e2c1ee6e2c29b9b6ffaff3e8978f4a43a529d /training/lr.py
parentUpdated 1-cycle scheduler (diff)
downloadtextual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.gz
textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.bz2
textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.zip
Training improvements
Diffstat (limited to 'training/lr.py')
-rw-r--r--training/lr.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/training/lr.py b/training/lr.py
index c0e9b3f..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -90,6 +90,7 @@ class LRFinder():
90 else: 90 else:
91 if smooth_f > 0: 91 if smooth_f > 0:
92 loss = smooth_f * loss + (1 - smooth_f) * losses[-1] 92 loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
93 acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
93 if loss < best_loss: 94 if loss < best_loss:
94 best_loss = loss 95 best_loss = loss
95 if acc > best_acc: 96 if acc > best_acc:
@@ -132,9 +133,9 @@ class LRFinder():
132 ax_loss.set_xlabel("Learning rate") 133 ax_loss.set_xlabel("Learning rate")
133 ax_loss.set_ylabel("Loss") 134 ax_loss.set_ylabel("Loss")
134 135
135 # ax_acc = ax_loss.twinx() 136 ax_acc = ax_loss.twinx()
136 # ax_acc.plot(lrs, accs, color='blue') 137 ax_acc.plot(lrs, accs, color='blue')
137 # ax_acc.set_ylabel("Accuracy") 138 ax_acc.set_ylabel("Accuracy")
138 139
139 print("LR suggestion: steepest gradient") 140 print("LR suggestion: steepest gradient")
140 min_grad_idx = None 141 min_grad_idx = None