author    Volpeon <git@volpeon.ink>  2023-01-02 20:13:59 +0100
committer Volpeon <git@volpeon.ink>  2023-01-02 20:13:59 +0100
commit    46d631759f59bc6b65458202641e5f5a9bc30b7b (patch)
tree      ea8c94ff336fe27b6cc8f39cea6c1699f44c61d5 /training
parent    Update (diff)
Fixed LR finder
Diffstat (limited to 'training')
 training/lr.py | 38 +++++++++++++++++++++-----------------
 1 file changed, 21 insertions(+), 17 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index fe166ed..acc01a2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,7 @@
 import math
 import copy
 from typing import Callable
+from functools import partial
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -41,7 +42,7 @@ class LRFinder():
 
     def run(
         self,
-        min_lr,
+        end_lr,
         skip_start: int = 10,
         skip_end: int = 5,
         num_epochs: int = 100,
@@ -57,7 +58,7 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
 
         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -152,29 +153,30 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         fig, ax_loss = plt.subplots()
+        ax_acc = ax_loss.twinx()
 
         ax_loss.plot(lrs, losses, color='red')
         ax_loss.set_xscale("log")
-        ax_loss.set_xlabel("Learning rate")
+        ax_loss.set_xlabel(f"Learning rate")
         ax_loss.set_ylabel("Loss")
 
-        ax_acc = ax_loss.twinx()
         ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_xscale("log")
         ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
 
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
@@ -196,8 +198,10 @@ class LRFinder():
         ax_loss.legend()
 
 
-def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
-    def lr_lambda(current_epoch: int):
-        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)
+def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_epoch: int):
+        return (end_lr / base_lr) ** (current_epoch / num_epochs)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
 
-    return LambdaLR(optimizer, lr_lambda, last_epoch)
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
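
What the fix does: LambdaLR multiplies each parameter group's base LR by whatever the lambda returns, but the old lr_lambda returned an absolute value, min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr), which ignores the group's base LR entirely. The new version returns a relative factor, (end_lr / base_lr) ** (current_epoch / num_epochs), with functools.partial binding each group's own base LR to its own lambda, so every group's LR sweeps geometrically from base_lr at epoch 0 to end_lr at epoch num_epochs. A minimal, self-contained sketch of the fixed schedule (the SGD optimizer, dummy parameter, and LR values are illustrative stand-ins, not taken from this repository):

import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
    # LambdaLR multiplies each group's base LR by the lambda's return value,
    # so the lambda must return a factor relative to base_lr, not an absolute LR.
    def lr_lambda(base_lr: float, current_epoch: int):
        return (end_lr / base_lr) ** (current_epoch / num_epochs)

    # One lambda per parameter group, each bound to that group's base LR.
    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]

    return LambdaLR(optimizer, lr_lambdas, last_epoch)


# Illustrative usage: sweep one dummy parameter from 1e-6 up to 1e-1 over 100 epochs.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-6)
scheduler = get_exponential_schedule(optimizer, end_lr=1e-1, num_epochs=100)

for epoch in range(100):
    optimizer.step()
    scheduler.step()

# After num_epochs scheduler steps, the LR has reached end_lr (up to rounding).
assert math.isclose(optimizer.param_groups[0]["lr"], 1e-1, rel_tol=1e-6)

The other half of the commit reorders the plotting code: the skip_start/skip_end trim now happens after plotting, so the figure shows the full sweep while only the steepest-gradient suggestion uses the trimmed arrays, and the accuracy axis gets its own set_xscale("log") so both y-axes share the log-scaled LR axis. A sketch of that ordering (the sigmoid-shaped loss/accuracy curves are synthetic, made up for illustration):

import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins for an LR sweep: loss falls, accuracy rises.
lrs = np.logspace(-6, -1, 100)
losses = 1.0 - 0.8 / (1.0 + np.exp(-(np.arange(100) - 50) / 10.0))
accs = 0.1 + 0.8 / (1.0 + np.exp(-(np.arange(100) - 50) / 10.0))

skip_start, skip_end = 10, 5

# Plot the full sweep; both y-axes share the same log-scaled x-axis.
fig, ax_loss = plt.subplots()
ax_acc = ax_loss.twinx()

ax_loss.plot(lrs, losses, color='red')
ax_loss.set_xscale("log")
ax_loss.set_xlabel("Learning rate")
ax_loss.set_ylabel("Loss")

ax_acc.plot(lrs, accs, color='blue')
ax_acc.set_xscale("log")
ax_acc.set_ylabel("Accuracy")

# Trim only the arrays used for the suggestion, so the noisy warm-up
# and divergence tails don't skew the steepest-gradient pick.
if skip_end == 0:
    lrs_t, losses_t = lrs[skip_start:], losses[skip_start:]
else:
    lrs_t, losses_t = lrs[skip_start:-skip_end], losses[skip_start:-skip_end]

min_grad_idx = np.gradient(losses_t).argmin()
print(f"LR suggestion: {lrs_t[min_grad_idx]:.2e}")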