diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/lr.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/training/lr.py b/training/lr.py index 68e0f72..dfb1743 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -48,7 +48,7 @@ class LRFinder(): | |||
48 | skip_start: int = 10, | 48 | skip_start: int = 10, |
49 | skip_end: int = 5, | 49 | skip_end: int = 5, |
50 | num_epochs: int = 100, | 50 | num_epochs: int = 100, |
51 | num_train_batches: int = 1, | 51 | num_train_batches: int = math.inf, |
52 | num_val_batches: int = math.inf, | 52 | num_val_batches: int = math.inf, |
53 | smooth_f: float = 0.05, | 53 | smooth_f: float = 0.05, |
54 | ): | 54 | ): |
@@ -156,6 +156,15 @@ class LRFinder(): | |||
156 | # self.model.load_state_dict(self.model_state) | 156 | # self.model.load_state_dict(self.model_state) |
157 | # self.optimizer.load_state_dict(self.optimizer_state) | 157 | # self.optimizer.load_state_dict(self.optimizer_state) |
158 | 158 | ||
159 | if skip_end == 0: | ||
160 | lrs = lrs[skip_start:] | ||
161 | losses = losses[skip_start:] | ||
162 | accs = accs[skip_start:] | ||
163 | else: | ||
164 | lrs = lrs[skip_start:-skip_end] | ||
165 | losses = losses[skip_start:-skip_end] | ||
166 | accs = accs[skip_start:-skip_end] | ||
167 | |||
159 | fig, ax_loss = plt.subplots() | 168 | fig, ax_loss = plt.subplots() |
160 | ax_acc = ax_loss.twinx() | 169 | ax_acc = ax_loss.twinx() |
161 | 170 | ||
@@ -171,15 +180,6 @@ class LRFinder(): | |||
171 | print("LR suggestion: steepest gradient") | 180 | print("LR suggestion: steepest gradient") |
172 | min_grad_idx = None | 181 | min_grad_idx = None |
173 | 182 | ||
174 | if skip_end == 0: | ||
175 | lrs = lrs[skip_start:] | ||
176 | losses = losses[skip_start:] | ||
177 | accs = accs[skip_start:] | ||
178 | else: | ||
179 | lrs = lrs[skip_start:-skip_end] | ||
180 | losses = losses[skip_start:-skip_end] | ||
181 | accs = accs[skip_start:-skip_end] | ||
182 | |||
183 | try: | 183 | try: |
184 | min_grad_idx = np.gradient(np.array(losses)).argmin() | 184 | min_grad_idx = np.gradient(np.array(losses)).argmin() |
185 | except ValueError: | 185 | except ValueError: |