Diffstat (limited to 'training/lr.py')
-rw-r--r--   training/lr.py   256
1 file changed, 27 insertions, 229 deletions
diff --git a/training/lr.py b/training/lr.py
index 9690738..f5b362f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,238 +1,36 @@
-import math
-from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
-from functools import partial
+from pathlib import Path
 
 import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from torch.optim.lr_scheduler import LambdaLR
-from tqdm.auto import tqdm
 
-from training.functional import TrainingCallbacks
-from training.util import AverageMeter
 
+def plot_metrics(
+    metrics: tuple[list[float], list[float], list[float]],
+    output_file: Path,
+    skip_start: int = 10,
+    skip_end: int = 5,
+):
+    lrs, losses, accs = metrics
 
-def noop(*args, **kwards):
-    pass
+    if skip_end == 0:
+        lrs = lrs[skip_start:]
+        losses = losses[skip_start:]
+        accs = accs[skip_start:]
+    else:
+        lrs = lrs[skip_start:-skip_end]
+        losses = losses[skip_start:-skip_end]
+        accs = accs[skip_start:-skip_end]
 
+    fig, ax_loss = plt.subplots()
+    ax_acc = ax_loss.twinx()
 
-def noop_ctx(*args, **kwards):
-    return nullcontext()
+    ax_loss.plot(lrs, losses, color='red')
+    ax_loss.set_xscale("log")
+    ax_loss.set_xlabel(f"Learning rate")
+    ax_loss.set_ylabel("Loss")
 
+    ax_acc.plot(lrs, accs, color='blue')
+    ax_acc.set_xscale("log")
+    ax_acc.set_ylabel("Accuracy")
 
-class LRFinder():
-    def __init__(
-        self,
-        accelerator,
-        optimizer,
-        train_dataloader,
-        val_dataloader,
-        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        callbacks: TrainingCallbacks = TrainingCallbacks()
-    ):
-        self.accelerator = accelerator
-        self.model = callbacks.on_model()
-        self.optimizer = optimizer
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.loss_fn = loss_fn
-        self.callbacks = callbacks
-
-        # self.model_state = copy.deepcopy(model.state_dict())
-        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
-
-    def run(
-        self,
-        end_lr,
-        skip_start: int = 10,
-        skip_end: int = 5,
-        num_epochs: int = 100,
-        num_train_batches: int = math.inf,
-        num_val_batches: int = math.inf,
-        smooth_f: float = 0.05,
-    ):
-        best_loss = None
-        best_acc = None
-
-        lrs = []
-        losses = []
-        accs = []
-
-        lr_scheduler = get_exponential_schedule(
-            self.optimizer,
-            end_lr,
-            num_epochs * min(num_train_batches, len(self.train_dataloader))
-        )
-
-        steps = min(num_train_batches, len(self.train_dataloader))
-        steps += min(num_val_batches, len(self.val_dataloader))
-        steps *= num_epochs
-
-        progress_bar = tqdm(
-            range(steps),
-            disable=not self.accelerator.is_local_main_process,
-            dynamic_ncols=True
-        )
-        progress_bar.set_description("Epoch X / Y")
-
-        self.callbacks.on_prepare()
-
-        on_train = self.callbacks.on_train
-        on_before_optimize = self.callbacks.on_before_optimize
-        on_after_optimize = self.callbacks.on_after_optimize
-        on_eval = self.callbacks.on_eval
-
-        for epoch in range(num_epochs):
-            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-
-            avg_loss = AverageMeter()
-            avg_acc = AverageMeter()
-
-            self.model.train()
-
-            with on_train(epoch):
-                for step, batch in enumerate(self.train_dataloader):
-                    if step >= num_train_batches:
-                        break
-
-                    with self.accelerator.accumulate(self.model):
-                        loss, acc, bsz = self.loss_fn(step, batch)
-
-                        self.accelerator.backward(loss)
-
-                        on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
-
-                        self.optimizer.step()
-                        lr_scheduler.step()
-                        self.optimizer.zero_grad(set_to_none=True)
-
-                    if self.accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                    progress_bar.update(1)
-
-            self.model.eval()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(self.val_dataloader):
-                        if step >= num_val_batches:
-                            break
-
-                        loss, acc, bsz = self.loss_fn(step, batch, True)
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                        progress_bar.update(1)
-
-            loss = avg_loss.avg.item()
-            acc = avg_acc.avg.item()
-
-            if epoch == 0:
-                best_loss = loss
-                best_acc = acc
-            else:
-                if smooth_f > 0:
-                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
-                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
-                if loss < best_loss:
-                    best_loss = loss
-                if acc > best_acc:
-                    best_acc = acc
-
-            lr = lr_scheduler.get_last_lr()[0]
-
-            lrs.append(lr)
-            losses.append(loss)
-            accs.append(acc)
-
-            self.accelerator.log({
-                "loss": loss,
-                "acc": acc,
-                "lr": lr,
-            }, step=epoch)
-
-            progress_bar.set_postfix({
-                "loss": loss,
-                "loss/best": best_loss,
-                "acc": acc,
-                "acc/best": best_acc,
-                "lr": lr,
-            })
-
-        # self.model.load_state_dict(self.model_state)
-        # self.optimizer.load_state_dict(self.optimizer_state)
-
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
-        fig, ax_loss = plt.subplots()
-        ax_acc = ax_loss.twinx()
-
-        ax_loss.plot(lrs, losses, color='red')
-        ax_loss.set_xscale("log")
-        ax_loss.set_xlabel(f"Learning rate")
-        ax_loss.set_ylabel("Loss")
-
-        ax_acc.plot(lrs, accs, color='blue')
-        ax_acc.set_xscale("log")
-        ax_acc.set_ylabel("Accuracy")
-
-        print("LR suggestion: steepest gradient")
-        min_grad_idx = None
-
-        try:
-            min_grad_idx = np.gradient(np.array(losses)).argmin()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        try:
-            max_val_idx = np.array(accs).argmax()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        if min_grad_idx is not None:
-            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
-            ax_loss.scatter(
-                lrs[min_grad_idx],
-                losses[min_grad_idx],
-                s=75,
-                marker="o",
-                color="red",
-                zorder=3,
-                label="steepest gradient",
-            )
-            ax_loss.legend()
-
-        if max_val_idx is not None:
-            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
-            ax_acc.scatter(
-                lrs[max_val_idx],
-                accs[max_val_idx],
-                s=75,
-                marker="o",
-                color="blue",
-                zorder=3,
-                label="maximum",
-            )
-            ax_acc.legend()
-
-
-def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
-    def lr_lambda(base_lr: float, current_epoch: int):
-        return (end_lr / base_lr) ** (current_epoch / num_epochs)
-
-    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
-
-    return LambdaLR(optimizer, lr_lambdas, last_epoch)
+    plt.savefig(output_file, dpi=300)
+    plt.close()
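
For reference, a minimal sketch of how the new plot_metrics helper might be called once an LR sweep has produced per-epoch metrics. The placeholder metric lists and the output path are illustrative assumptions; only the plot_metrics signature and the training.lr module path come from this diff.

    from pathlib import Path

    from training.lr import plot_metrics

    # Placeholder metrics standing in for values collected during an LR sweep
    # (one entry per epoch); a real run would gather these from its training loop.
    lrs = [1e-6 * (1e-2 / 1e-6) ** (i / 99) for i in range(100)]  # exponential LR sweep
    losses = [2.0 / (1.0 + i) for i in range(100)]                # dummy loss curve
    accs = [i / 100 for i in range(100)]                          # dummy accuracy curve

    # plot_metrics trims the first skip_start and last skip_end points, then writes
    # a log-x plot with loss on the left axis and accuracy on the right axis.
    plot_metrics((lrs, losses, accs), Path("lr_finder.png"), skip_start=10, skip_end=5)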
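
The removed get_exponential_schedule helper drove the sweep: LambdaLR multiplies each parameter group's base LR by the returned factor, so the effective rate was lr(t) = base_lr * (end_lr / base_lr) ** (t / T), a geometric ramp from base_lr at t = 0 to end_lr at t = T. A small self-contained check of that formula, using illustrative base_lr, end_lr, and T values not taken from the diff:

    # Illustrative values; the removed helper took end_lr and the step count as
    # arguments and read base_lr from the optimizer's param groups.
    base_lr, end_lr, T = 1e-6, 1e-2, 100

    def lr_at(t: int) -> float:
        # Geometric interpolation between base_lr and end_lr over T steps.
        return base_lr * (end_lr / base_lr) ** (t / T)

    assert abs(lr_at(0) - base_lr) < 1e-12   # starts at the optimizer's base LR
    assert abs(lr_at(T) - end_lr) < 1e-12    # ends at the requested end_lr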
