import math
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR
from tqdm.auto import tqdm

from training.util import AverageMeter


class LRFinder:
    """Learning-rate range test: sweep the LR upward across epochs, record the
    validation loss/accuracy at each step, and suggest a starting LR."""

    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
        self.accelerator = accelerator
        self.model = model
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.loss_fn = loss_fn
        # self.model_state = copy.deepcopy(model.state_dict())
        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())

    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1,
            num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
        best_loss = None
        best_acc = None
        lrs = []
        losses = []
        accs = []

        # Scheduler that ramps the LR from `min_lr * base_lr` up to `base_lr`
        # over `num_epochs` epochs.
        lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)

        # Total number of progress-bar steps: train + val batches, per epoch.
        steps = min(num_train_batches, len(self.train_dataloader))
        steps += min(num_val_batches, len(self.val_dataloader))
        steps *= num_epochs

        progress_bar = tqdm(
            range(steps),
            disable=not self.accelerator.is_local_main_process,
            dynamic_ncols=True,
        )
        progress_bar.set_description("Epoch X / Y")

        for epoch in range(num_epochs):
            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")

            avg_loss = AverageMeter()
            avg_acc = AverageMeter()

            # Train on a few batches at the current LR.
            self.model.train()
            for step, batch in enumerate(self.train_dataloader):
                if step >= num_train_batches:
                    break

                with self.accelerator.accumulate(self.model):
                    loss, acc, bsz = self.loss_fn(batch)
                    self.accelerator.backward(loss)
                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)

                if self.accelerator.sync_gradients:
                    progress_bar.update(1)

            # Evaluate on the validation set to measure the effect of this LR.
            self.model.eval()
            with torch.inference_mode():
                for step, batch in enumerate(self.val_dataloader):
                    if step >= num_val_batches:
                        break

                    loss, acc, bsz = self.loss_fn(batch)
                    avg_loss.update(loss.detach_(), bsz)
                    avg_acc.update(acc.detach_(), bsz)

                    progress_bar.update(1)

            lr_scheduler.step()

            loss = avg_loss.avg.item()
            acc = avg_acc.avg.item()

            if epoch == 0:
                best_loss = loss
                best_acc = acc
            else:
                # Exponentially smooth the metrics and track the best values.
                if smooth_f > 0:
                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
                if loss < best_loss:
                    best_loss = loss
                if acc > best_acc:
                    best_acc = acc

            lr = lr_scheduler.get_last_lr()[0]
            lrs.append(lr)
            losses.append(loss)
            accs.append(acc)

            progress_bar.set_postfix({
                "loss": loss,
                "loss/best": best_loss,
                "acc": acc,
                "acc/best": best_acc,
                "lr": lr,
            })

            # self.model.load_state_dict(self.model_state)
            # self.optimizer.load_state_dict(self.optimizer_state)

            if loss > diverge_th * best_loss:
                print("Stopping early, the loss has diverged")
                break

        # Trim the noisy beginning and end of the sweep before plotting.
        if skip_end == 0:
            lrs = lrs[skip_start:]
            losses = losses[skip_start:]
            accs = accs[skip_start:]
        else:
            lrs = lrs[skip_start:-skip_end]
            losses = losses[skip_start:-skip_end]
            accs = accs[skip_start:-skip_end]

        # Plot loss and accuracy against the (log-scaled) learning rate.
        fig, ax_loss = plt.subplots()
        ax_loss.plot(lrs, losses, color="red")
        ax_loss.set_xscale("log")
        ax_loss.set_xlabel("Learning rate")
        ax_loss.set_ylabel("Loss")

        ax_acc = ax_loss.twinx()
        ax_acc.plot(lrs, accs, color="blue")
        ax_acc.set_ylabel("Accuracy")

        # Suggest the LR where the loss falls fastest (steepest negative gradient).
        print("LR suggestion: steepest gradient")
        min_grad_idx = None
        try:
            min_grad_idx = np.gradient(np.array(losses)).argmin()
        except ValueError:
            print("Failed to compute the gradients, there might not be enough points.")
        if min_grad_idx is not None:
            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
            ax_loss.scatter(
                lrs[min_grad_idx],
                losses[min_grad_idx],
                s=75,
                marker="o",
                color="red",
                zorder=3,
                label="steepest gradient",
            )
            ax_loss.legend()


def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
    # Multiplicative LR factor that grows from `min_lr` to 1 over `num_epochs`
    # epochs along a degree-10 polynomial curve, so most of the sweep is spent
    # at the low end of the LR range.
    def lr_lambda(current_epoch: int):
        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
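

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It wires up a hypothetical dummy model and random dataloaders purely to show
# the expected call pattern; the `loss_fn(batch) -> (loss, acc, batch_size)`
# contract is the one `LRFinder.run` relies on above. Names such as
# `make_loader` and the dummy shapes are assumptions made for this example,
# and `accelerate` is assumed to be installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, TensorDataset

    accelerator = Accelerator()

    model = torch.nn.Linear(16, 4)
    # Base LR of 1.0 so the sweep covers roughly [min_lr, 1.0].
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

    def make_loader():
        x = torch.randn(64, 16)
        y = torch.randint(0, 4, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=8)

    model, optimizer, train_dl, val_dl = accelerator.prepare(
        model, optimizer, make_loader(), make_loader()
    )

    def loss_fn(batch):
        inputs, targets = batch
        logits = model(inputs)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        acc = (logits.argmax(dim=-1) == targets).float().mean()
        return loss, acc, targets.size(0)

    finder = LRFinder(accelerator, model, optimizer, train_dl, val_dl, loss_fn)
    finder.run(min_lr=1e-6, num_epochs=50)
    plt.show()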