author     Volpeon <git@volpeon.ink>  2023-01-05 18:55:41 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-05 18:55:41 +0100
commit     181d56a0af567309a6fda4bfc4e2243ad5f4ca06
tree       a547bbce7e4b8a8c888109e017a7f7b187dc0eff
parent     Update
Fix LR finder
training/lr.py | 30 +++++++++++++++++++++++-------
1 file changed, 23 insertions(+), 7 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 3cdf994..c765150 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -51,7 +51,7 @@ class LRFinder():
         num_train_batches: int = 1,
         num_val_batches: int = math.inf,
         smooth_f: float = 0.05,
-        diverge_th: int = 5
+        diverge_th: int = 5,
     ):
         best_loss = None
         best_acc = None
@@ -157,10 +157,6 @@ class LRFinder():
             # self.model.load_state_dict(self.model_state)
             # self.optimizer.load_state_dict(self.optimizer_state)
 
-            if loss > diverge_th * best_loss:
-                print("Stopping early, the loss has diverged")
-                break
-
         fig, ax_loss = plt.subplots()
         ax_acc = ax_loss.twinx()
 
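
The deleted block is the sweep's early-stopping heuristic: abort once the current loss exceeds diverge_th times the best loss seen so far. A minimal standalone sketch of that heuristic (the loop and all values are invented for illustration, not taken from training/lr.py):

    # Sketch of the removed divergence check; losses are invented.
    losses = [2.0, 1.5, 1.0, 1.2, 9.0]
    diverge_th = 5
    best_loss = None

    for loss in losses:
        # Track the lowest loss seen so far during the ramp.
        if best_loss is None or loss < best_loss:
            best_loss = loss
        # Stop once the loss blows up relative to that best value.
        if loss > diverge_th * best_loss:
            print("Stopping early, the loss has diverged")
            break
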
@@ -186,14 +182,21 @@ class LRFinder():
         accs = accs[skip_start:-skip_end]
 
         try:
-            min_grad_idx = (np.gradient(np.array(losses))).argmin()
+            min_grad_idx = np.gradient(np.array(losses)).argmin()
+        except ValueError:
+            print(
+                "Failed to compute the gradients, there might not be enough points."
+            )
+
+        try:
+            max_val_idx = np.array(accs).argmax()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
 
         if min_grad_idx is not None:
-            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
+            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
             ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
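
On the loss side the heuristic is unchanged apart from the dropped redundant parentheses: np.gradient approximates the slope of the loss curve and argmin() picks the index where it falls fastest. A self-contained sketch with invented numbers:

    import numpy as np

    # Invented sweep results for illustration only.
    lrs = np.logspace(-6, -1, 6)
    losses = np.array([2.0, 1.9, 1.4, 0.6, 0.9, 5.0])

    # np.gradient approximates d(loss)/d(index); argmin() marks the
    # steepest descent, printed as the loss-based suggestion.
    min_grad_idx = np.gradient(losses).argmin()
    print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))

np.gradient raises ValueError when fewer than two points remain after the skip_start/skip_end trimming, and numpy's argmax raises ValueError only for an empty array, so both try blocks guard the too-few-points case (the second branch reuses the gradient error message).
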
@@ -205,6 +208,19 @@ class LRFinder():
             )
             ax_loss.legend()
 
+        if max_val_idx is not None:
+            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
+            ax_acc.scatter(
+                lrs[max_val_idx],
+                accs[max_val_idx],
+                s=75,
+                marker="o",
+                color="blue",
+                zorder=3,
+                label="maximum",
+            )
+            ax_acc.legend()
+
 
 def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
     def lr_lambda(base_lr: float, current_epoch: int):
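
The accuracy-based suggestion added at the end mirrors the loss-based one: the candidate LR is the one at which validation accuracy peaked, marked on the twinned axis (ax_acc = ax_loss.twinx()) so both suggestions share one figure. A matching sketch, again with invented values:

    import numpy as np

    # Invented sweep results for illustration only.
    lrs = np.logspace(-6, -1, 6)
    accs = [0.10, 0.30, 0.55, 0.70, 0.62, 0.05]

    # The accuracy-based suggestion is simply the argmax of the
    # recorded validation accuracies.
    max_val_idx = np.array(accs).argmax()
    print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
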