From 9d6c75262b6919758e781b8333428861a5bf7ede Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Dec 2022 11:02:49 +0100
Subject: Added learning rate finder

---
 training/lr.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 training/lr.py

diff --git a/training/lr.py b/training/lr.py
new file mode 100644
index 0000000..dd37baa
--- /dev/null
+++ b/training/lr.py
@@ -0,0 +1,115 @@
+import numpy as np
+from torch.optim.lr_scheduler import LambdaLR
+from tqdm.auto import tqdm
+import matplotlib.pyplot as plt
+
+from training.util import AverageMeter
+
+
+class LRFinder():
+    def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
+        self.accelerator = accelerator
+        self.model = model
+        self.optimizer = optimizer
+        self.train_dataloader = train_dataloader
+        self.loss_fn = loss_fn
+
+    def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
+        best_loss = None
+        lrs = []
+        losses = []
+
+        lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
+
+        progress_bar = tqdm(
+            range(num_epochs * num_steps),
+            disable=not self.accelerator.is_local_main_process,
+            dynamic_ncols=True
+        )
+        progress_bar.set_description("Epoch X / Y")
+
+        for epoch in range(num_epochs):
+            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+
+            avg_loss = AverageMeter()
+
+            for step, batch in enumerate(self.train_dataloader):
+                with self.accelerator.accumulate(self.model):
+                    loss, acc, bsz = self.loss_fn(batch)
+
+                    self.accelerator.backward(loss)
+
+                    self.optimizer.step()
+                    self.optimizer.zero_grad(set_to_none=True)
+
+                    avg_loss.update(loss.detach_(), bsz)
+
+                if step >= num_steps:
+                    break
+
+                if self.accelerator.sync_gradients:
+                    progress_bar.update(1)
+
+            lr_scheduler.step()
+
+            loss = avg_loss.avg.item()
+            if epoch == 0:
+                best_loss = loss
+            else:
+                if smooth_f > 0:
+                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
+                if loss < best_loss:
+                    best_loss = loss
+
+            lr = lr_scheduler.get_last_lr()[0]
+
+            lrs.append(lr)
+            losses.append(loss)
+
+            progress_bar.set_postfix({
+                "loss": loss,
+                "best": best_loss,
+                "lr": lr,
+            })
+
+            if loss > diverge_th * best_loss:
+                print("Stopping early, the loss has diverged")
+                break
+
+        fig, ax = plt.subplots()
+        ax.plot(lrs, losses)
+
+        print("LR suggestion: steepest gradient")
+        min_grad_idx = None
+        try:
+            min_grad_idx = (np.gradient(np.array(losses))).argmin()
+        except ValueError:
+            print(
+                "Failed to compute the gradients, there might not be enough points."
+            )
+        if min_grad_idx is not None:
+            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
+            ax.scatter(
+                lrs[min_grad_idx],
+                losses[min_grad_idx],
+                s=75,
+                marker="o",
+                color="red",
+                zorder=3,
+                label="steepest gradient",
+            )
+            ax.legend()
+
+        ax.set_xscale("log")
+        ax.set_xlabel("Learning rate")
+        ax.set_ylabel("Loss")
+
+        if fig is not None:
+            plt.show()
+
+
+def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
+    def lr_lambda(current_epoch: int):
+        return (current_epoch / num_epochs) ** 5
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
--
cgit v1.2.3-70-g09d2
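
Usage note (not part of the patch above): a minimal sketch of how the new LRFinder might be driven, assuming an accelerate-based training setup. The constructor signature and the loss_fn(batch) -> (loss, acc, bsz) contract come from training/lr.py as added in this commit; the Accelerator wiring, the placeholder model, and the synthetic dataloader below are illustrative assumptions only.

# Hypothetical usage sketch (not from the commit). The real project builds
# its model, optimizer, and dataloader elsewhere; placeholders are used here
# so the example is self-contained.
import torch
from accelerate import Accelerator

from training.lr import LRFinder

accelerator = Accelerator()

model = torch.nn.Linear(16, 1)  # placeholder model
# get_exponential_schedule ramps the LR from ~0 up to the optimizer's base LR,
# so the base LR set here is the upper end of the sweep.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-1)

dataset = torch.utils.data.TensorDataset(torch.randn(256, 16), torch.randn(256, 1))
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)


def loss_fn(batch):
    # LRFinder.run() expects loss_fn(batch) to return (loss, acc, bsz).
    inputs, targets = batch
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    acc = torch.zeros((), device=loss.device)  # accuracy is not used by the finder
    return loss, acc, inputs.shape[0]


lr_finder = LRFinder(accelerator, model, optimizer, train_dataloader, loss_fn)
lr_finder.run(num_epochs=100, num_steps=1)  # plots loss vs. LR and prints a suggested value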