path: root/training/lr.py
author     Volpeon <git@volpeon.ink>    2023-01-20 14:26:17 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-20 14:26:17 +0100
commit     3575d041f1507811b577fd2c653171fb51c0a386 (patch)
tree       702f9f1ae4eafc6f8ea06560c4de6bbe1c2acecb /training/lr.py
parent     Move Accelerator preparation into strategy (diff)
Restored LR finder
Diffstat (limited to 'training/lr.py')
-rw-r--r--  training/lr.py  266
1 file changed, 32 insertions(+), 234 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 9690738..f5b362f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,238 +1,36 @@
-import math
-from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
-from functools import partial
+from pathlib import Path
 
 import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from torch.optim.lr_scheduler import LambdaLR
-from tqdm.auto import tqdm
 
-from training.functional import TrainingCallbacks
-from training.util import AverageMeter
 
-
-def noop(*args, **kwards):
-    pass
-
-
-def noop_ctx(*args, **kwards):
-    return nullcontext()
-
-
-class LRFinder():
-    def __init__(
-        self,
-        accelerator,
-        optimizer,
-        train_dataloader,
-        val_dataloader,
-        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        callbacks: TrainingCallbacks = TrainingCallbacks()
-    ):
-        self.accelerator = accelerator
-        self.model = callbacks.on_model()
-        self.optimizer = optimizer
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.loss_fn = loss_fn
-        self.callbacks = callbacks
-
-        # self.model_state = copy.deepcopy(model.state_dict())
-        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
-
-    def run(
-        self,
-        end_lr,
-        skip_start: int = 10,
-        skip_end: int = 5,
-        num_epochs: int = 100,
-        num_train_batches: int = math.inf,
-        num_val_batches: int = math.inf,
-        smooth_f: float = 0.05,
-    ):
-        best_loss = None
-        best_acc = None
-
-        lrs = []
-        losses = []
-        accs = []
-
-        lr_scheduler = get_exponential_schedule(
-            self.optimizer,
-            end_lr,
-            num_epochs * min(num_train_batches, len(self.train_dataloader))
-        )
-
-        steps = min(num_train_batches, len(self.train_dataloader))
-        steps += min(num_val_batches, len(self.val_dataloader))
-        steps *= num_epochs
-
-        progress_bar = tqdm(
-            range(steps),
-            disable=not self.accelerator.is_local_main_process,
-            dynamic_ncols=True
-        )
-        progress_bar.set_description("Epoch X / Y")
-
-        self.callbacks.on_prepare()
-
-        on_train = self.callbacks.on_train
-        on_before_optimize = self.callbacks.on_before_optimize
-        on_after_optimize = self.callbacks.on_after_optimize
-        on_eval = self.callbacks.on_eval
-
-        for epoch in range(num_epochs):
-            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-
-            avg_loss = AverageMeter()
-            avg_acc = AverageMeter()
-
-            self.model.train()
-
-            with on_train(epoch):
-                for step, batch in enumerate(self.train_dataloader):
-                    if step >= num_train_batches:
-                        break
-
-                    with self.accelerator.accumulate(self.model):
-                        loss, acc, bsz = self.loss_fn(step, batch)
-
-                        self.accelerator.backward(loss)
-
-                        on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
-
-                        self.optimizer.step()
-                        lr_scheduler.step()
-                        self.optimizer.zero_grad(set_to_none=True)
-
-                    if self.accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                    progress_bar.update(1)
-
-            self.model.eval()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(self.val_dataloader):
-                        if step >= num_val_batches:
-                            break
-
-                        loss, acc, bsz = self.loss_fn(step, batch, True)
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                        progress_bar.update(1)
-
-            loss = avg_loss.avg.item()
-            acc = avg_acc.avg.item()
-
-            if epoch == 0:
-                best_loss = loss
-                best_acc = acc
-            else:
-                if smooth_f > 0:
-                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
-                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
-                if loss < best_loss:
-                    best_loss = loss
-                if acc > best_acc:
-                    best_acc = acc
-
-            lr = lr_scheduler.get_last_lr()[0]
-
-            lrs.append(lr)
-            losses.append(loss)
-            accs.append(acc)
-
-            self.accelerator.log({
-                "loss": loss,
-                "acc": acc,
-                "lr": lr,
-            }, step=epoch)
-
-            progress_bar.set_postfix({
-                "loss": loss,
-                "loss/best": best_loss,
-                "acc": acc,
-                "acc/best": best_acc,
-                "lr": lr,
-            })
-
-        # self.model.load_state_dict(self.model_state)
-        # self.optimizer.load_state_dict(self.optimizer_state)
-
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
-        fig, ax_loss = plt.subplots()
-        ax_acc = ax_loss.twinx()
-
-        ax_loss.plot(lrs, losses, color='red')
-        ax_loss.set_xscale("log")
-        ax_loss.set_xlabel(f"Learning rate")
-        ax_loss.set_ylabel("Loss")
-
-        ax_acc.plot(lrs, accs, color='blue')
-        ax_acc.set_xscale("log")
-        ax_acc.set_ylabel("Accuracy")
-
-        print("LR suggestion: steepest gradient")
-        min_grad_idx = None
-
-        try:
-            min_grad_idx = np.gradient(np.array(losses)).argmin()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        try:
-            max_val_idx = np.array(accs).argmax()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        if min_grad_idx is not None:
-            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
-            ax_loss.scatter(
-                lrs[min_grad_idx],
-                losses[min_grad_idx],
-                s=75,
-                marker="o",
-                color="red",
-                zorder=3,
-                label="steepest gradient",
-            )
-            ax_loss.legend()
-
-        if max_val_idx is not None:
-            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
-            ax_acc.scatter(
-                lrs[max_val_idx],
-                accs[max_val_idx],
-                s=75,
-                marker="o",
-                color="blue",
-                zorder=3,
-                label="maximum",
-            )
-            ax_acc.legend()
-
-
-def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
-    def lr_lambda(base_lr: float, current_epoch: int):
-        return (end_lr / base_lr) ** (current_epoch / num_epochs)
-
-    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
-
-    return LambdaLR(optimizer, lr_lambdas, last_epoch)
+def plot_metrics(
+    metrics: tuple[list[float], list[float], list[float]],
+    output_file: Path,
+    skip_start: int = 10,
+    skip_end: int = 5,
+):
+    lrs, losses, accs = metrics
+
+    if skip_end == 0:
+        lrs = lrs[skip_start:]
+        losses = losses[skip_start:]
+        accs = accs[skip_start:]
+    else:
+        lrs = lrs[skip_start:-skip_end]
+        losses = losses[skip_start:-skip_end]
+        accs = accs[skip_start:-skip_end]
+
+    fig, ax_loss = plt.subplots()
+    ax_acc = ax_loss.twinx()
+
+    ax_loss.plot(lrs, losses, color='red')
+    ax_loss.set_xscale("log")
+    ax_loss.set_xlabel(f"Learning rate")
+    ax_loss.set_ylabel("Loss")
+
+    ax_acc.plot(lrs, accs, color='blue')
+    ax_acc.set_xscale("log")
+    ax_acc.set_ylabel("Accuracy")
+
+    plt.savefig(output_file, dpi=300)
+    plt.close()
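For context, a minimal sketch of how the trimmed-down module might be driven. The metric lists below are fabricated purely for illustration; in the real training loop they would come from whatever now collects the LR-finder statistics.

# Hypothetical driver for plot_metrics; the sweep values below are made up.
from pathlib import Path

import numpy as np

from training.lr import plot_metrics

# Fake LR sweep: 100 exponentially spaced learning rates with toy curves.
lrs = np.geomspace(1e-6, 1e-1, 100).tolist()
losses = [1.0 / (1.0 + 10.0 * lr) for lr in lrs]
accs = [1.0 - loss for loss in losses]

# plot_metrics drops skip_start/skip_end points and writes a log-scaled
# loss/accuracy-vs-LR figure to the given path.
plot_metrics((lrs, losses, accs), Path("lr_finder.png"), skip_start=10, skip_end=5)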
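The ramp that drove the old finder is the removed get_exponential_schedule: each parameter group's LR grows geometrically from its base value to end_lr, i.e. lr(t) = base_lr * (end_lr / base_lr) ** (t / num_steps). Below is a self-contained sketch of the same helper (the argument is renamed num_steps here, since the old run() stepped the scheduler once per batch) together with a quick numeric check.

# Sketch of the removed exponential schedule, with a numeric sanity check.
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_exponential_schedule(optimizer, end_lr: float, num_steps: int, last_epoch: int = -1):
    def lr_lambda(base_lr: float, step: int):
        # LambdaLR multiplies base_lr by this factor, so after num_steps
        # scheduler steps the effective LR is exactly end_lr.
        return (end_lr / base_lr) ** (step / num_steps)

    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
    return LambdaLR(optimizer, lr_lambdas, last_epoch)


# Ramp a dummy optimizer from 1e-6 to 1e-1 over 100 steps.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-6)
scheduler = get_exponential_schedule(optimizer, end_lr=1e-1, num_steps=100)
for _ in range(100):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # approximately [0.1]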
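The removed run() also printed two LR suggestions: the point where the loss curve falls fastest (most negative np.gradient) and the point of highest accuracy. A toy illustration of those two rules, using synthetic curves that exist only to exercise the code:

# Toy reconstruction of the two suggestion rules; the curves are synthetic.
import numpy as np

lrs = np.geomspace(1e-6, 1e-1, 100)
# Flat loss for tiny LRs, a steep drop around 1e-4, divergence near 1e-1.
losses = (
    1.0
    - 0.8 / (1.0 + np.exp(-4.0 * (np.log10(lrs) + 4.0)))
    + np.maximum(np.log10(lrs) + 2.0, 0.0) ** 2
)
accs = 1.0 - losses

min_grad_idx = np.gradient(losses).argmin()   # steepest descent of the loss
max_val_idx = accs.argmax()                   # best accuracy

print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))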