import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR
from tqdm.auto import tqdm

from training.util import AverageMeter


class LRFinder:
    """Learning-rate range test: sweep the learning rate upward over a number of
    epochs, record the (smoothed) training loss at each value, and suggest the
    learning rate where the loss falls fastest."""

    def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
        self.accelerator = accelerator
        self.model = model
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.loss_fn = loss_fn

    def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
        best_loss = None
        lrs = []
        losses = []
        lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
        progress_bar = tqdm(
            range(num_epochs * num_steps),
            disable=not self.accelerator.is_local_main_process,
            dynamic_ncols=True,
        )
        progress_bar.set_description("Epoch X / Y")

        for epoch in range(num_epochs):
            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            avg_loss = AverageMeter()

            # Train on a handful of batches at the current learning rate.
            for step, batch in enumerate(self.train_dataloader):
                with self.accelerator.accumulate(self.model):
                    loss, acc, bsz = self.loss_fn(batch)
                    self.accelerator.backward(loss)
                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)
                    avg_loss.update(loss.detach_(), bsz)
                if step >= num_steps:
                    break
                if self.accelerator.sync_gradients:
                    progress_bar.update(1)

            # Increase the learning rate for the next epoch.
            lr_scheduler.step()

            # Smooth the loss (exponential moving average) so the curve and the
            # suggestion below are less noisy.
            loss = avg_loss.avg.item()
            if epoch == 0:
                best_loss = loss
            else:
                if smooth_f > 0:
                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                if loss < best_loss:
                    best_loss = loss

            lr = lr_scheduler.get_last_lr()[0]
            lrs.append(lr)
            losses.append(loss)
            progress_bar.set_postfix({"loss": loss, "best": best_loss, "lr": lr})

            # Stop once the loss clearly diverges from the best value seen so far.
            if loss > diverge_th * best_loss:
                print("Stopping early, the loss has diverged")
                break

        # Plot loss vs. learning rate and mark the steepest-descent suggestion.
        fig, ax = plt.subplots()
        ax.plot(lrs, losses)
        print("LR suggestion: steepest gradient")
        min_grad_idx = None
        try:
            # Index where the loss decreases fastest (most negative gradient).
            min_grad_idx = np.gradient(np.array(losses)).argmin()
        except ValueError:
            print("Failed to compute the gradients, there might not be enough points.")
        if min_grad_idx is not None:
            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
            ax.scatter(
                lrs[min_grad_idx],
                losses[min_grad_idx],
                s=75,
                marker="o",
                color="red",
                zorder=3,
                label="steepest gradient",
            )
            ax.legend()
        ax.set_xscale("log")
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Loss")
        if fig is not None:
            plt.show()


def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
    # Note: despite the name, this ramps the learning rate polynomially
    # (degree 5) from 0 up to the optimizer's base LR over `num_epochs`.
    def lr_lambda(current_epoch: int):
        return (current_epoch / num_epochs) ** 5

    return LambdaLR(optimizer, lr_lambda, last_epoch)
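

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring LRFinder up with Hugging Face Accelerate.
# The toy model, dataset, and `compute_loss` closure below are hypothetical
# placeholders; the loss function is assumed to return (loss, accuracy,
# batch_size), which is the shape LRFinder.run() unpacks.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Toy regression setup standing in for a real model / dataloader.
    model = torch.nn.Linear(16, 1)
    # The optimizer's base LR is the upper bound of the sweep; the schedule
    # ramps from 0 up to this value over num_epochs.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-1)
    dataset = TensorDataset(torch.randn(512, 16), torch.randn(512, 1))
    train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    def compute_loss(batch):
        # Must return (loss, accuracy, batch_size); accuracy is unused by LRFinder.
        inputs, targets = batch
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        return loss, None, inputs.size(0)

    finder = LRFinder(accelerator, model, optimizer, train_dataloader, compute_loss)
    finder.run(num_epochs=50, num_steps=1)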