import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR
from tqdm.auto import tqdm

from training.util import AverageMeter  # project-local running-average helper


class LRFinder:
    """Learning-rate range test: sweep the LR upward across epochs, record the
    validation loss at each LR, and plot loss vs. LR to suggest a starting LR."""

    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
        self.accelerator = accelerator
        self.model = model
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.loss_fn = loss_fn
    def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5):
        best_loss = None
        lrs = []
        losses = []
        lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
        progress_bar = tqdm(
            range(num_epochs * (num_train_steps + num_val_steps)),
            disable=not self.accelerator.is_local_main_process,
            dynamic_ncols=True,
        )
        for epoch in range(num_epochs):
            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            avg_loss = AverageMeter()

            # Train for a few steps at the current learning rate.
            self.model.train()
            for step, batch in enumerate(self.train_dataloader):
                with self.accelerator.accumulate(self.model):
                    loss, acc, bsz = self.loss_fn(batch)
                    self.accelerator.backward(loss)
                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)
                if self.accelerator.sync_gradients:
                    progress_bar.update(1)
                if step + 1 >= num_train_steps:  # run exactly num_train_steps batches
                    break

            # Measure the validation loss at this learning rate.
            self.model.eval()
            with torch.inference_mode():
                for step, batch in enumerate(self.val_dataloader):
                    loss, acc, bsz = self.loss_fn(batch)
                    avg_loss.update(loss.detach(), bsz)
                    progress_bar.update(1)
                    if step + 1 >= num_val_steps:
                        break

            # Record the LR that produced this loss before advancing the schedule.
            lr = lr_scheduler.get_last_lr()[0]
            lr_scheduler.step()

            loss = avg_loss.avg.item()
            if epoch == 0:
                best_loss = loss
            else:
                if smooth_f > 0:
                    # Exponentially smooth the loss against the previous point.
                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                if loss < best_loss:
                    best_loss = loss
            lrs.append(lr)
            losses.append(loss)
            progress_bar.set_postfix({
                "loss": loss,
                "best": best_loss,
                "lr": lr,
            })
            if loss > diverge_th * best_loss:
                print("Stopping early, the loss has diverged")
                break

        # Plot loss vs. LR (log scale) and mark the point of steepest descent
        # as the suggested LR, as torch-lr-finder does.
        fig, ax = plt.subplots()
        ax.plot(lrs, losses)
        print("LR suggestion: steepest gradient")
        min_grad_idx = None
        try:
            min_grad_idx = np.gradient(np.array(losses)).argmin()
        except ValueError:
            print("Failed to compute the gradients, there might not be enough points.")
        if min_grad_idx is not None:
            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
            ax.scatter(
                lrs[min_grad_idx],
                losses[min_grad_idx],
                s=75,
                marker="o",
                color="red",
                zorder=3,
                label="steepest gradient",
            )
            ax.legend()
        ax.set_xscale("log")
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Loss")


def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
    # Despite the name, this ramp is polynomial (power 5), not exponential: the
    # multiplier grows from (1 / num_epochs) ** 5 up to 1, so the sweep tops out
    # at the optimizer's configured LR and spends most epochs at small LRs.
    # The +1 keeps the first epoch's LR non-zero.
    def lr_lambda(current_epoch: int):
        return ((current_epoch + 1) / num_epochs) ** 5

    return LambdaLR(optimizer, lr_lambda, last_epoch)
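
A minimal usage sketch, assuming a loss_fn that returns the (loss, accuracy, batch_size) tuple that run() unpacks; the toy model, data, and base LR below are illustrative assumptions, not part of the original class:

from accelerate import Accelerator
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Toy data and model purely for illustration.
x, y = torch.randn(256, 20), torch.randint(0, 4, (256,))
train_dl = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
val_dl = DataLoader(TensorDataset(x, y), batch_size=32)
model = torch.nn.Linear(20, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # base LR = top of the sweep

accelerator = Accelerator()
model, optimizer, train_dl, val_dl = accelerator.prepare(model, optimizer, train_dl, val_dl)

def loss_fn(batch):
    # Must return (loss, accuracy, batch_size), matching what run() expects.
    inputs, targets = batch
    logits = model(inputs)
    loss = F.cross_entropy(logits, targets)
    acc = (logits.argmax(dim=-1) == targets).float().mean()
    return loss, acc, targets.size(0)

finder = LRFinder(accelerator, model, optimizer, train_dl, val_dl, loss_fn)
finder.run(num_epochs=100, num_train_steps=1, num_val_steps=1)
plt.show()  # run() builds the figure; show it (or rely on inline notebook display)

Note that the optimizer's configured LR is the ceiling of the sweep: with num_epochs=100 and a base LR of 1.0, the schedule ramps the LR from (1 / 100) ** 5 = 1e-10 up to 1.0.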