summaryrefslogtreecommitdiffstats
path: root/training/lr.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-05 18:55:41 +0100
committerVolpeon <git@volpeon.ink>2023-01-05 18:55:41 +0100
commit181d56a0af567309a6fda4bfc4e2243ad5f4ca06 (patch)
treea547bbce7e4b8a8c888109e017a7f7b187dc0eff /training/lr.py
parentUpdate (diff)
downloadtextual-inversion-diff-181d56a0af567309a6fda4bfc4e2243ad5f4ca06.tar.gz
textual-inversion-diff-181d56a0af567309a6fda4bfc4e2243ad5f4ca06.tar.bz2
textual-inversion-diff-181d56a0af567309a6fda4bfc4e2243ad5f4ca06.zip
Fix LR finder
Diffstat (limited to 'training/lr.py')
-rw-r--r--training/lr.py30
1 files changed, 23 insertions, 7 deletions
diff --git a/training/lr.py b/training/lr.py
index 3cdf994..c765150 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -51,7 +51,7 @@ class LRFinder():
51 num_train_batches: int = 1, 51 num_train_batches: int = 1,
52 num_val_batches: int = math.inf, 52 num_val_batches: int = math.inf,
53 smooth_f: float = 0.05, 53 smooth_f: float = 0.05,
54 diverge_th: int = 5 54 diverge_th: int = 5,
55 ): 55 ):
56 best_loss = None 56 best_loss = None
57 best_acc = None 57 best_acc = None
@@ -157,10 +157,6 @@ class LRFinder():
157 # self.model.load_state_dict(self.model_state) 157 # self.model.load_state_dict(self.model_state)
158 # self.optimizer.load_state_dict(self.optimizer_state) 158 # self.optimizer.load_state_dict(self.optimizer_state)
159 159
160 if loss > diverge_th * best_loss:
161 print("Stopping early, the loss has diverged")
162 break
163
164 fig, ax_loss = plt.subplots() 160 fig, ax_loss = plt.subplots()
165 ax_acc = ax_loss.twinx() 161 ax_acc = ax_loss.twinx()
166 162
@@ -186,14 +182,21 @@ class LRFinder():
186 accs = accs[skip_start:-skip_end] 182 accs = accs[skip_start:-skip_end]
187 183
188 try: 184 try:
189 min_grad_idx = (np.gradient(np.array(losses))).argmin() 185 min_grad_idx = np.gradient(np.array(losses)).argmin()
186 except ValueError:
187 print(
188 "Failed to compute the gradients, there might not be enough points."
189 )
190
191 try:
192 max_val_idx = np.array(accs).argmax()
190 except ValueError: 193 except ValueError:
191 print( 194 print(
192 "Failed to compute the gradients, there might not be enough points." 195 "Failed to compute the gradients, there might not be enough points."
193 ) 196 )
194 197
195 if min_grad_idx is not None: 198 if min_grad_idx is not None:
196 print("Suggested LR: {:.2E}".format(lrs[min_grad_idx])) 199 print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
197 ax_loss.scatter( 200 ax_loss.scatter(
198 lrs[min_grad_idx], 201 lrs[min_grad_idx],
199 losses[min_grad_idx], 202 losses[min_grad_idx],
@@ -205,6 +208,19 @@ class LRFinder():
205 ) 208 )
206 ax_loss.legend() 209 ax_loss.legend()
207 210
211 if max_val_idx is not None:
212 print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
213 ax_acc.scatter(
214 lrs[max_val_idx],
215 accs[max_val_idx],
216 s=75,
217 marker="o",
218 color="blue",
219 zorder=3,
220 label="maximum",
221 )
222 ax_acc.legend()
223
208 224
209def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1): 225def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
210 def lr_lambda(base_lr: float, current_epoch: int): 226 def lr_lambda(base_lr: float, current_epoch: int):