Diffstat (limited to 'training/lr.py')
-rw-r--r--  training/lr.py | 266
1 file changed, 32 insertions(+), 234 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 9690738..f5b362f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,238 +1,36 @@
-import math
-from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
-from functools import partial
+from pathlib import Path
 
 import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from torch.optim.lr_scheduler import LambdaLR
-from tqdm.auto import tqdm
 
-from training.functional import TrainingCallbacks
-from training.util import AverageMeter
 
-
-def noop(*args, **kwards):
-    pass
-
-
-def noop_ctx(*args, **kwards):
-    return nullcontext()
-
-
-class LRFinder():
-    def __init__(
-        self,
-        accelerator,
-        optimizer,
-        train_dataloader,
-        val_dataloader,
-        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        callbacks: TrainingCallbacks = TrainingCallbacks()
-    ):
-        self.accelerator = accelerator
-        self.model = callbacks.on_model()
-        self.optimizer = optimizer
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.loss_fn = loss_fn
-        self.callbacks = callbacks
-
-        # self.model_state = copy.deepcopy(model.state_dict())
-        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
-
-    def run(
-        self,
-        end_lr,
-        skip_start: int = 10,
-        skip_end: int = 5,
-        num_epochs: int = 100,
-        num_train_batches: int = math.inf,
-        num_val_batches: int = math.inf,
-        smooth_f: float = 0.05,
-    ):
-        best_loss = None
-        best_acc = None
-
-        lrs = []
-        losses = []
-        accs = []
-
-        lr_scheduler = get_exponential_schedule(
-            self.optimizer,
-            end_lr,
-            num_epochs * min(num_train_batches, len(self.train_dataloader))
-        )
-
-        steps = min(num_train_batches, len(self.train_dataloader))
-        steps += min(num_val_batches, len(self.val_dataloader))
-        steps *= num_epochs
-
-        progress_bar = tqdm(
-            range(steps),
-            disable=not self.accelerator.is_local_main_process,
-            dynamic_ncols=True
-        )
-        progress_bar.set_description("Epoch X / Y")
-
-        self.callbacks.on_prepare()
-
-        on_train = self.callbacks.on_train
-        on_before_optimize = self.callbacks.on_before_optimize
-        on_after_optimize = self.callbacks.on_after_optimize
-        on_eval = self.callbacks.on_eval
-
-        for epoch in range(num_epochs):
-            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-
-            avg_loss = AverageMeter()
-            avg_acc = AverageMeter()
-
-            self.model.train()
-
-            with on_train(epoch):
-                for step, batch in enumerate(self.train_dataloader):
-                    if step >= num_train_batches:
-                        break
-
-                    with self.accelerator.accumulate(self.model):
-                        loss, acc, bsz = self.loss_fn(step, batch)
-
-                        self.accelerator.backward(loss)
-
-                        on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
-
-                        self.optimizer.step()
-                        lr_scheduler.step()
-                        self.optimizer.zero_grad(set_to_none=True)
-
-                    if self.accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                    progress_bar.update(1)
-
-            self.model.eval()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(self.val_dataloader):
-                        if step >= num_val_batches:
-                            break
-
-                        loss, acc, bsz = self.loss_fn(step, batch, True)
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                        progress_bar.update(1)
-
-            loss = avg_loss.avg.item()
-            acc = avg_acc.avg.item()
-
-            if epoch == 0:
-                best_loss = loss
-                best_acc = acc
-            else:
-                if smooth_f > 0:
-                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
-                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
-                if loss < best_loss:
-                    best_loss = loss
-                if acc > best_acc:
-                    best_acc = acc
-
-            lr = lr_scheduler.get_last_lr()[0]
-
-            lrs.append(lr)
-            losses.append(loss)
-            accs.append(acc)
-
-            self.accelerator.log({
-                "loss": loss,
-                "acc": acc,
-                "lr": lr,
-            }, step=epoch)
-
-            progress_bar.set_postfix({
-                "loss": loss,
-                "loss/best": best_loss,
-                "acc": acc,
-                "acc/best": best_acc,
-                "lr": lr,
-            })
-
-        # self.model.load_state_dict(self.model_state)
-        # self.optimizer.load_state_dict(self.optimizer_state)
-
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
-        fig, ax_loss = plt.subplots()
-        ax_acc = ax_loss.twinx()
-
-        ax_loss.plot(lrs, losses, color='red')
-        ax_loss.set_xscale("log")
-        ax_loss.set_xlabel(f"Learning rate")
-        ax_loss.set_ylabel("Loss")
-
-        ax_acc.plot(lrs, accs, color='blue')
-        ax_acc.set_xscale("log")
-        ax_acc.set_ylabel("Accuracy")
-
-        print("LR suggestion: steepest gradient")
-        min_grad_idx = None
-
-        try:
-            min_grad_idx = np.gradient(np.array(losses)).argmin()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        try:
-            max_val_idx = np.array(accs).argmax()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        if min_grad_idx is not None:
-            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
-            ax_loss.scatter(
-                lrs[min_grad_idx],
-                losses[min_grad_idx],
-                s=75,
-                marker="o",
-                color="red",
-                zorder=3,
-                label="steepest gradient",
-            )
-            ax_loss.legend()
-
-        if max_val_idx is not None:
-            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
-            ax_acc.scatter(
-                lrs[max_val_idx],
-                accs[max_val_idx],
-                s=75,
-                marker="o",
-                color="blue",
-                zorder=3,
-                label="maximum",
-            )
-            ax_acc.legend()
-
-
-def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
-    def lr_lambda(base_lr: float, current_epoch: int):
-        return (end_lr / base_lr) ** (current_epoch / num_epochs)
-
-    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
-
-    return LambdaLR(optimizer, lr_lambdas, last_epoch)
+def plot_metrics(
+    metrics: tuple[list[float], list[float], list[float]],
+    output_file: Path,
+    skip_start: int = 10,
+    skip_end: int = 5,
+):
+    lrs, losses, accs = metrics
+
+    if skip_end == 0:
+        lrs = lrs[skip_start:]
+        losses = losses[skip_start:]
+        accs = accs[skip_start:]
+    else:
+        lrs = lrs[skip_start:-skip_end]
+        losses = losses[skip_start:-skip_end]
+        accs = accs[skip_start:-skip_end]
+
+    fig, ax_loss = plt.subplots()
+    ax_acc = ax_loss.twinx()
+
+    ax_loss.plot(lrs, losses, color='red')
+    ax_loss.set_xscale("log")
+    ax_loss.set_xlabel(f"Learning rate")
+    ax_loss.set_ylabel("Loss")
+
+    ax_acc.plot(lrs, accs, color='blue')
+    ax_acc.set_xscale("log")
+    ax_acc.set_ylabel("Accuracy")
+
+    plt.savefig(output_file, dpi=300)
+    plt.close()
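For context, a minimal sketch of how the new plot_metrics helper might be driven. The metric lists below are hypothetical placeholders standing in for a recorded learning-rate sweep (one lr/loss/accuracy triple per logged step), since the LRFinder loop that previously produced them is removed by this commit:

    from pathlib import Path

    from training.lr import plot_metrics

    # Hypothetical sweep data: 50 points along an exponentially
    # increasing learning rate, with made-up loss/accuracy curves
    # just to exercise the plot.
    lrs = [1e-6 * 10 ** (i / 10) for i in range(50)]
    losses = [1.0 / (1.0 + i) for i in range(50)]
    accs = [i / 50 for i in range(50)]

    # skip_start/skip_end default to 10 and 5, trimming the noisy
    # edge points before plotting loss (red) and accuracy (blue)
    # against the log-scaled learning rate.
    plot_metrics((lrs, losses, accs), Path("lr_sweep.png"))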