Diffstat (limited to 'training')
-rw-r--r--   training/lr.py   115
1 file changed, 115 insertions(+), 0 deletions(-)
diff --git a/training/lr.py b/training/lr.py
new file mode 100644
index 0000000..dd37baa
--- /dev/null
+++ b/training/lr.py
@@ -0,0 +1,115 @@
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from training.util import AverageMeter

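# Learning-rate range test: train for a short sweep while the learning rate is
# ramped up each epoch, record the (smoothed) loss per epoch, and plot loss
# against learning rate to suggest a starting value.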
class LRFinder():
    def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
        self.accelerator = accelerator
        self.model = model
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.loss_fn = loss_fn

    def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
        best_loss = None
        lrs = []
        losses = []
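        # The schedule ramps the LR up once per epoch, from 0 towards the
        # optimizer's base learning rate (see get_exponential_schedule below).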
        lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)

        progress_bar = tqdm(
            range(num_epochs * num_steps),
            disable=not self.accelerator.is_local_main_process,
            dynamic_ncols=True
        )
        progress_bar.set_description("Epoch X / Y")

        for epoch in range(num_epochs):
            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")

            avg_loss = AverageMeter()

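            # Train on a few batches at the current learning rate and track the
            # running average loss for this epoch.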
            for step, batch in enumerate(self.train_dataloader):
                with self.accelerator.accumulate(self.model):
                    loss, acc, bsz = self.loss_fn(batch)

                    self.accelerator.backward(loss)

                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)

                    avg_loss.update(loss.detach_(), bsz)

                if step >= num_steps:
                    break

                if self.accelerator.sync_gradients:
                    progress_bar.update(1)

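            # One scheduler step per epoch: the next epoch trains at a higher LR.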
            lr_scheduler.step()

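            # Smooth the epoch loss with an exponential moving average and keep
            # track of the best (lowest) loss seen so far.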
            loss = avg_loss.avg.item()
            if epoch == 0:
                best_loss = loss
            else:
                if smooth_f > 0:
                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                if loss < best_loss:
                    best_loss = loss

            lr = lr_scheduler.get_last_lr()[0]

            lrs.append(lr)
            losses.append(loss)

            progress_bar.set_postfix({
                "loss": loss,
                "best": best_loss,
                "lr": lr,
            })

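            # Abort the sweep once the loss exceeds diverge_th times the best loss.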
            if loss > diverge_th * best_loss:
                print("Stopping early, the loss has diverged")
                break

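        # Plot loss against learning rate and suggest the LR at the point where
        # the loss curve drops most steeply (most negative gradient).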
        fig, ax = plt.subplots()
        ax.plot(lrs, losses)

        print("LR suggestion: steepest gradient")
        min_grad_idx = None
        try:
            min_grad_idx = (np.gradient(np.array(losses))).argmin()
        except ValueError:
            print(
                "Failed to compute the gradients, there might not be enough points."
            )
        if min_grad_idx is not None:
            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
            ax.scatter(
                lrs[min_grad_idx],
                losses[min_grad_idx],
                s=75,
                marker="o",
                color="red",
                zorder=3,
                label="steepest gradient",
            )
            ax.legend()

        ax.set_xscale("log")
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Loss")

        if fig is not None:
            plt.show()

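# Note: despite the name, this is a polynomial ramp. LambdaLR multiplies the
# optimizer's base LR by (current_epoch / num_epochs) ** 5, so the sweep starts
# at LR 0 and only reaches the full base LR at the very end of the run.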
def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
    def lr_lambda(current_epoch: int):
        return (current_epoch / num_epochs) ** 5

    return LambdaLR(optimizer, lr_lambda, last_epoch)
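# A minimal usage sketch (not part of this commit): `model`, `optimizer`,
# `train_dataloader` and `loss_fn` are assumed to come from the surrounding
# training setup, with `loss_fn(batch)` returning a (loss, accuracy, batch_size)
# tuple as expected by LRFinder.run().
#
#     from accelerate import Accelerator
#     from training.lr import LRFinder
#
#     accelerator = Accelerator()
#     model, optimizer, train_dataloader = accelerator.prepare(
#         model, optimizer, train_dataloader
#     )
#     lr_finder = LRFinder(accelerator, model, optimizer, train_dataloader, loss_fn)
#     lr_finder.run(num_epochs=100, num_steps=1)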