Diffstat (limited to 'training')
 -rw-r--r--  training/lr.py | 38
 1 file changed, 21 insertions(+), 17 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index fe166ed..acc01a2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,7 @@
 import math
 import copy
 from typing import Callable
+from functools import partial
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -41,7 +42,7 @@ class LRFinder():
 
     def run(
         self,
-        min_lr,
+        end_lr,
         skip_start: int = 10,
         skip_end: int = 5,
        num_epochs: int = 100,
@@ -57,7 +58,7 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
 
         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -152,29 +153,30 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         fig, ax_loss = plt.subplots()
+        ax_acc = ax_loss.twinx()
 
         ax_loss.plot(lrs, losses, color='red')
         ax_loss.set_xscale("log")
-        ax_loss.set_xlabel("Learning rate")
+        ax_loss.set_xlabel(f"Learning rate")
         ax_loss.set_ylabel("Loss")
 
-        ax_acc = ax_loss.twinx()
         ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_xscale("log")
         ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
 
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
@@ -196,8 +198,10 @@ class LRFinder():
         ax_loss.legend()
 
 
-def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
-    def lr_lambda(current_epoch: int):
-        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)
+def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_epoch: int):
+        return (end_lr / base_lr) ** (current_epoch / num_epochs)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
 
-    return LambdaLR(optimizer, lr_lambda, last_epoch)
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
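The rewritten factory makes LambdaLR multiply each parameter group's base learning rate by (end_lr / base_lr) ** (epoch / num_epochs), so the rate sweeps geometrically from the optimizer's starting lr up to end_lr over num_epochs steps, one lambda per param group. Below is a minimal standalone sketch of that behaviour; the SGD optimizer, the toy Linear model, and the epoch loop are illustrative assumptions and not part of this commit.

# Sketch: exponential LR sweep, mirroring the get_exponential_schedule
# added in this commit (standalone usage is assumed for illustration).
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
    # LambdaLR multiplies each group's base lr by the returned factor,
    # so the factor is expressed relative to base_lr rather than as an
    # absolute learning rate.
    def lr_lambda(base_lr: float, current_epoch: int):
        return (end_lr / base_lr) ** (current_epoch / num_epochs)

    # One lambda per param group, each closed over that group's base lr.
    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
    return LambdaLR(optimizer, lr_lambdas, last_epoch)


model = torch.nn.Linear(4, 2)                               # toy model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)    # base_lr of the sweep
scheduler = get_exponential_schedule(optimizer, end_lr=1.0, num_epochs=100)

for epoch in range(100):
    # ... one train/val pass of the LR range test would run here ...
    scheduler.step()
    if epoch % 25 == 0:
        print(epoch, scheduler.get_last_lr())  # lr grows geometrically toward end_lr

At epoch 0 the factor is 1 (lr == base_lr) and at epoch == num_epochs it is end_lr / base_lr (lr == end_lr), which is why the previous min_lr parameter and the polynomial ((epoch / num_epochs) ** 10) form are no longer needed.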
