diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/lr.py | 20 |
1 file changed, 10 insertions, 10 deletions
diff --git a/training/lr.py b/training/lr.py index 68e0f72..dfb1743 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -48,7 +48,7 @@ class LRFinder(): | |||
| 48 | skip_start: int = 10, | 48 | skip_start: int = 10, |
| 49 | skip_end: int = 5, | 49 | skip_end: int = 5, |
| 50 | num_epochs: int = 100, | 50 | num_epochs: int = 100, |
| 51 | num_train_batches: int = 1, | 51 | num_train_batches: int = math.inf, |
| 52 | num_val_batches: int = math.inf, | 52 | num_val_batches: int = math.inf, |
| 53 | smooth_f: float = 0.05, | 53 | smooth_f: float = 0.05, |
| 54 | ): | 54 | ): |
| @@ -156,6 +156,15 @@ class LRFinder(): | |||
| 156 | # self.model.load_state_dict(self.model_state) | 156 | # self.model.load_state_dict(self.model_state) |
| 157 | # self.optimizer.load_state_dict(self.optimizer_state) | 157 | # self.optimizer.load_state_dict(self.optimizer_state) |
| 158 | 158 | ||
| 159 | if skip_end == 0: | ||
| 160 | lrs = lrs[skip_start:] | ||
| 161 | losses = losses[skip_start:] | ||
| 162 | accs = accs[skip_start:] | ||
| 163 | else: | ||
| 164 | lrs = lrs[skip_start:-skip_end] | ||
| 165 | losses = losses[skip_start:-skip_end] | ||
| 166 | accs = accs[skip_start:-skip_end] | ||
| 167 | |||
| 159 | fig, ax_loss = plt.subplots() | 168 | fig, ax_loss = plt.subplots() |
| 160 | ax_acc = ax_loss.twinx() | 169 | ax_acc = ax_loss.twinx() |
| 161 | 170 | ||
| @@ -171,15 +180,6 @@ class LRFinder(): | |||
| 171 | print("LR suggestion: steepest gradient") | 180 | print("LR suggestion: steepest gradient") |
| 172 | min_grad_idx = None | 181 | min_grad_idx = None |
| 173 | 182 | ||
| 174 | if skip_end == 0: | ||
| 175 | lrs = lrs[skip_start:] | ||
| 176 | losses = losses[skip_start:] | ||
| 177 | accs = accs[skip_start:] | ||
| 178 | else: | ||
| 179 | lrs = lrs[skip_start:-skip_end] | ||
| 180 | losses = losses[skip_start:-skip_end] | ||
| 181 | accs = accs[skip_start:-skip_end] | ||
| 182 | |||
| 183 | try: | 183 | try: |
| 184 | min_grad_idx = np.gradient(np.array(losses)).argmin() | 184 | min_grad_idx = np.gradient(np.array(losses)).argmin() |
| 185 | except ValueError: | 185 | except ValueError: |
