summaryrefslogtreecommitdiffstats
path: root/training/lr.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/lr.py')
-rw-r--r--training/lr.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/training/lr.py b/training/lr.py
index c0e9b3f..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -90,6 +90,7 @@ class LRFinder():
90 else: 90 else:
91 if smooth_f > 0: 91 if smooth_f > 0:
92 loss = smooth_f * loss + (1 - smooth_f) * losses[-1] 92 loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
93 acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
93 if loss < best_loss: 94 if loss < best_loss:
94 best_loss = loss 95 best_loss = loss
95 if acc > best_acc: 96 if acc > best_acc:
@@ -132,9 +133,9 @@ class LRFinder():
132 ax_loss.set_xlabel("Learning rate") 133 ax_loss.set_xlabel("Learning rate")
133 ax_loss.set_ylabel("Loss") 134 ax_loss.set_ylabel("Loss")
134 135
135 # ax_acc = ax_loss.twinx() 136 ax_acc = ax_loss.twinx()
136 # ax_acc.plot(lrs, accs, color='blue') 137 ax_acc.plot(lrs, accs, color='blue')
137 # ax_acc.set_ylabel("Accuracy") 138 ax_acc.set_ylabel("Accuracy")
138 139
139 print("LR suggestion: steepest gradient") 140 print("LR suggestion: steepest gradient")
140 min_grad_idx = None 141 min_grad_idx = None