Diffstat (limited to 'training')

 training/lr.py   | 46 ++++++++++++++++++++++++++++++++++++++++----------
 training/util.py |  5 -----
 2 files changed, 36 insertions(+), 15 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 8e558e1..c1fa3a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -22,10 +22,13 @@ class LRFinder():
         self.model_state = copy.deepcopy(model.state_dict())
         self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
+        best_acc = None
+
         lrs = []
         losses = []
+        accs = []
 
         lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
 
@@ -44,6 +47,7 @@ class LRFinder():
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
             avg_loss = AverageMeter()
+            avg_acc = AverageMeter()
 
             self.model.train()
 
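AverageMeter is used above but its definition is outside this diff; as a reading aid, here is a minimal sketch of the interface the new avg_acc code appears to rely on (an update(value, n) method and an .avg attribute that still supports .item() when fed tensors). This is an illustration, not the repository's actual class:

class AverageMeter:
    """Running weighted average, e.g. of a per-batch loss or accuracy."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # `value` is a batch average; weight it by the batch size `n`.
        # If `value` is a tensor, `sum` and `avg` stay tensors, so .item() works later.
        self.sum = self.sum + value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)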
@@ -71,28 +75,37 @@ class LRFinder():
 
                 loss, acc, bsz = self.loss_fn(batch)
                 avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
                 progress_bar.update(1)
 
             lr_scheduler.step()
 
             loss = avg_loss.avg.item()
+            acc = avg_acc.avg.item()
+
             if epoch == 0:
                 best_loss = loss
+                best_acc = acc
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                 if loss < best_loss:
                     best_loss = loss
+                if acc > best_acc:
+                    best_acc = acc
 
             lr = lr_scheduler.get_last_lr()[0]
 
             lrs.append(lr)
             losses.append(loss)
+            accs.append(acc)
 
             progress_bar.set_postfix({
                 "loss": loss,
-                "best": best_loss,
+                "loss/best": best_loss,
+                "acc": acc,
+                "acc/best": best_acc,
                 "lr": lr,
             })
 
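The smoothing applied in the hunk above is a plain exponential moving average over the per-epoch losses, which damps noise before the best-loss and divergence checks. A standalone illustration with made-up numbers (smooth_f matches the default; the loss values do not come from this repository):

smooth_f = 0.05
raw_losses = [2.0, 1.8, 1.9, 1.5]

smoothed = [raw_losses[0]]
for loss in raw_losses[1:]:
    # Blend the current loss with the previous smoothed value.
    smoothed.append(smooth_f * loss + (1 - smooth_f) * smoothed[-1])

print(smoothed)  # [2.0, 1.99, 1.9855, 1.961225]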
@@ -103,20 +116,37 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        fig, ax = plt.subplots()
-        ax.plot(lrs, losses)
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
+        fig, ax_loss = plt.subplots()
+
+        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.set_xscale("log")
+        ax_loss.set_xlabel("Learning rate")
+
+        # ax_acc = ax_loss.twinx()
+        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
+
         if min_grad_idx is not None:
             print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
-            ax.scatter(
+            ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
                 s=75,
@@ -125,11 +155,7 @@ class LRFinder():
                 zorder=3,
                 label="steepest gradient",
             )
-            ax.legend()
-
-        ax.set_xscale("log")
-        ax.set_xlabel("Learning rate")
-        ax.set_ylabel("Loss")
+            ax_loss.legend()
 
 
 def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
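The body of get_exponential_schedule is outside this diff. One plausible shape for it, assuming the finder sweeps each param group's learning rate up from min_lr to its configured value over num_epochs steps (both the sweep direction and the use of LambdaLR are assumptions, not facts from the repository):

from torch.optim.lr_scheduler import LambdaLR

def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
    # Hypothetical implementation: lr(epoch) = min_lr * (base_lr / min_lr) ** (epoch / num_epochs),
    # expressed as a multiplier on base_lr, which is what LambdaLR expects.
    def factor(epoch, base_lr):
        return (min_lr / base_lr) ** (1 - epoch / num_epochs)

    lambdas = [
        (lambda epoch, base_lr=group["lr"]: factor(epoch, base_lr))
        for group in optimizer.param_groups
    ]
    return LambdaLR(optimizer, lambdas, last_epoch=last_epoch)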
diff --git a/training/util.py b/training/util.py
index a0c15cd..d0f7fcd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -5,11 +5,6 @@ import torch
 from PIL import Image
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
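The removed freeze_params helper has a one-line built-in equivalent in PyTorch, which is presumably why it could be dropped. A sketch of the inline replacement at a hypothetical call site (the encoder module below is a placeholder, not code from this repository):

import torch.nn as nn

encoder = nn.Linear(128, 64)  # placeholder for whatever module was being frozen

# Built-in equivalent of the removed helper: flip the flag on every parameter.
encoder.requires_grad_(False)

# Or parameter by parameter, exactly what freeze_params did:
for param in encoder.parameters():
    param.requires_grad = False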
