import math
from contextlib import _GeneratorContextManager, nullcontext
from functools import partial
from typing import Any, Callable, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR
from tqdm.auto import tqdm

from training.util import AverageMeter


def noop(*args, **kwargs):
    pass


class LRFinder:
    def __init__(
        self,
        accelerator,
        model,
        optimizer,
        train_dataloader,
        val_dataloader,
        loss_fn: Union[
            Callable[[int, Any], Tuple[Any, Any, int]],
            Callable[[int, Any, bool], Tuple[Any, Any, int]],
        ],
        on_train: Callable[[], _GeneratorContextManager] = nullcontext,
        on_clip: Callable[[], None] = noop,
        on_eval: Callable[[], _GeneratorContextManager] = nullcontext,
    ):
        self.accelerator = accelerator
        self.model = model
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.loss_fn = loss_fn
        self.on_train = on_train
        self.on_clip = on_clip
        self.on_eval = on_eval
        # self.model_state = copy.deepcopy(model.state_dict())
        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())

    def run(
        self,
        end_lr,
        skip_start: int = 10,
        skip_end: int = 5,
        num_epochs: int = 100,
        num_train_batches: int = 1,
        num_val_batches: Union[int, float] = math.inf,
        smooth_f: float = 0.05,
    ):
        best_loss = None
        best_acc = None
        lrs = []
        losses = []
        accs = []

        # Exponentially anneal the LR from each group's base LR up to `end_lr`
        # over the total number of training steps.
        lr_scheduler = get_exponential_schedule(
            self.optimizer,
            end_lr,
            num_epochs * min(num_train_batches, len(self.train_dataloader)),
        )

        steps = min(num_train_batches, len(self.train_dataloader))
        steps += min(num_val_batches, len(self.val_dataloader))
        steps *= num_epochs
        progress_bar = tqdm(
            range(steps),
            disable=not self.accelerator.is_local_main_process,
            dynamic_ncols=True,
        )
        progress_bar.set_description("Epoch X / Y")

        for epoch in range(num_epochs):
            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")

            avg_loss = AverageMeter()
            avg_acc = AverageMeter()

            # Train for a few batches at the current learning rate.
            self.model.train()
            with self.on_train():
                for step, batch in enumerate(self.train_dataloader):
                    if step >= num_train_batches:
                        break
                    with self.accelerator.accumulate(self.model):
                        loss, acc, bsz = self.loss_fn(step, batch)
                        self.accelerator.backward(loss)
                        if self.accelerator.sync_gradients:
                            self.on_clip()
                        self.optimizer.step()
                        lr_scheduler.step()
                        self.optimizer.zero_grad(set_to_none=True)
                    if self.accelerator.sync_gradients:
                        progress_bar.update(1)

            # Evaluate on the validation set to score this learning rate.
            self.model.eval()
            with torch.inference_mode():
                with self.on_eval():
                    for step, batch in enumerate(self.val_dataloader):
                        if step >= num_val_batches:
                            break
                        loss, acc, bsz = self.loss_fn(step, batch, True)
                        avg_loss.update(loss.detach_(), bsz)
                        avg_acc.update(acc.detach_(), bsz)
                        progress_bar.update(1)

            loss = avg_loss.avg.item()
            acc = avg_acc.avg.item()
            if epoch == 0:
                best_loss = loss
                best_acc = acc
            else:
                # Exponentially smooth the validation metrics before recording them.
                if smooth_f > 0:
                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
                if loss < best_loss:
                    best_loss = loss
                if acc > best_acc:
                    best_acc = acc

            lr = lr_scheduler.get_last_lr()[0]
            lrs.append(lr)
            losses.append(loss)
            accs.append(acc)

            self.accelerator.log({
                "loss": loss,
                "acc": acc,
                "lr": lr,
            }, step=epoch)
            progress_bar.set_postfix({
                "loss": loss,
                "loss/best": best_loss,
                "acc": acc,
                "acc/best": best_acc,
                "lr": lr,
            })

        # self.model.load_state_dict(self.model_state)
        # self.optimizer.load_state_dict(self.optimizer_state)

        # Plot loss and accuracy against learning rate on a shared log-scale x-axis.
        fig, ax_loss = plt.subplots()
        ax_acc = ax_loss.twinx()
        ax_loss.plot(lrs, losses, color="red")
        ax_loss.set_xscale("log")
        ax_loss.set_xlabel("Learning rate")
        ax_loss.set_ylabel("Loss")
        ax_acc.plot(lrs, accs, color="blue")
        ax_acc.set_xscale("log")
        ax_acc.set_ylabel("Accuracy")

        print("LR suggestion: steepest gradient")

        # Trim the ends of the curves, which tend to be noisy, before suggesting an LR.
        min_grad_idx = None
        max_val_idx = None
        if skip_end == 0:
            lrs = lrs[skip_start:]
            losses = losses[skip_start:]
            accs = accs[skip_start:]
        else:
            lrs = lrs[skip_start:-skip_end]
            losses = losses[skip_start:-skip_end]
            accs = accs[skip_start:-skip_end]

        try:
            # Learning rate at which the loss falls most steeply.
            min_grad_idx = np.gradient(np.array(losses)).argmin()
        except ValueError:
            print(
                "Failed to compute the gradients, there might not be enough points."
            )
        try:
            # Learning rate at which the validation accuracy peaks.
            max_val_idx = np.array(accs).argmax()
        except ValueError:
            print(
                "Failed to find the maximum accuracy, there might not be enough points."
            )

        if min_grad_idx is not None:
            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
            ax_loss.scatter(
                lrs[min_grad_idx],
                losses[min_grad_idx],
                s=75,
                marker="o",
                color="red",
                zorder=3,
                label="steepest gradient",
            )
            ax_loss.legend()

        if max_val_idx is not None:
            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
            ax_acc.scatter(
                lrs[max_val_idx],
                accs[max_val_idx],
                s=75,
                marker="o",
                color="blue",
                zorder=3,
                label="maximum",
            )
            ax_acc.legend()


def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
    # Exponentially interpolate each parameter group's LR from its base value to `end_lr`
    # over `num_epochs` scheduler steps.
    def lr_lambda(base_lr: float, current_epoch: int):
        return (end_lr / base_lr) ** (current_epoch / num_epochs)

    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
    return LambdaLR(optimizer, lr_lambdas, last_epoch)
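

# A minimal usage sketch (illustrative only): it assumes an `accelerate`-style setup in
# which `accelerator`, `model`, `optimizer`, the two dataloaders, and a `loss_fn` with
# the signature `(step, batch, is_eval=False) -> (loss, acc, batch_size)` have already
# been prepared elsewhere; none of these names are defined in this module.
#
#     finder = LRFinder(
#         accelerator, model, optimizer,
#         train_dataloader, val_dataloader, loss_fn,
#     )
#     finder.run(end_lr=1.0, num_epochs=50, num_train_batches=2)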