 training/lr.py | 30 +++++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 3cdf994..c765150 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -51,7 +51,7 @@ class LRFinder():
         num_train_batches: int = 1,
         num_val_batches: int = math.inf,
         smooth_f: float = 0.05,
-        diverge_th: int = 5
+        diverge_th: int = 5,
     ):
         best_loss = None
         best_acc = None
@@ -157,10 +157,6 @@ class LRFinder():
         # self.model.load_state_dict(self.model_state)
         # self.optimizer.load_state_dict(self.optimizer_state)
 
-        if loss > diverge_th * best_loss:
-            print("Stopping early, the loss has diverged")
-            break
-
         fig, ax_loss = plt.subplots()
         ax_acc = ax_loss.twinx()
 
@@ -186,14 +182,21 @@ class LRFinder():
         accs = accs[skip_start:-skip_end]
 
         try:
-            min_grad_idx = (np.gradient(np.array(losses))).argmin()
+            min_grad_idx = np.gradient(np.array(losses)).argmin()
+        except ValueError:
+            print(
+                "Failed to compute the gradients, there might not be enough points."
+            )
+
+        try:
+            max_val_idx = np.array(accs).argmax()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
 
         if min_grad_idx is not None:
-            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
+            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
             ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
@@ -205,6 +208,19 @@ class LRFinder():
             )
             ax_loss.legend()
 
+        if max_val_idx is not None:
+            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
+            ax_acc.scatter(
+                lrs[max_val_idx],
+                accs[max_val_idx],
+                s=75,
+                marker="o",
+                color="blue",
+                zorder=3,
+                label="maximum",
+            )
+            ax_acc.legend()
+
 
 def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
     def lr_lambda(base_lr: float, current_epoch: int):
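
For reference, a minimal standalone sketch of the suggestion logic this change introduces: the loss-based suggestion takes the learning rate where the loss curve falls fastest (minimum of the numerical gradient), and the new accuracy-based suggestion takes the learning rate with the highest recorded accuracy. The lrs/losses/accs arrays below are synthetic stand-ins for the values LRFinder collects during the range test; only numpy is assumed.

    import numpy as np

    # Dummy data standing in for the values recorded by LRFinder.
    lrs = np.logspace(-5, 0, 100)                  # candidate learning rates
    losses = np.log10(lrs) ** 2 + 0.1 * np.random.rand(100)   # synthetic loss curve
    accs = 1.0 - losses / losses.max()             # synthetic accuracy curve

    min_grad_idx = None
    max_val_idx = None

    try:
        # Loss-based suggestion: index where the loss decreases fastest.
        min_grad_idx = np.gradient(np.array(losses)).argmin()
    except ValueError:
        print("Failed to compute the gradients, there might not be enough points.")

    try:
        # Accuracy-based suggestion: index of the highest recorded accuracy.
        max_val_idx = np.array(accs).argmax()
    except ValueError:
        print("Not enough points to find an accuracy maximum.")

    if min_grad_idx is not None:
        print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
    if max_val_idx is not None:
        print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))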
