From 83725794618164210a12843381724252fdd82cc2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 28 Dec 2022 18:08:36 +0100
Subject: Integrated updates from diffusers

---
 training/lr.py   | 46 ++++++++++++++++++++++++++++++++++++----------
 training/util.py |  5 -----
 2 files changed, 36 insertions(+), 15 deletions(-)

(limited to 'training')

diff --git a/training/lr.py b/training/lr.py
index 8e558e1..c1fa3a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -22,10 +22,13 @@ class LRFinder():
         self.model_state = copy.deepcopy(model.state_dict())
         self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
+        best_acc = None
+
         lrs = []
         losses = []
+        accs = []
 
         lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
 
@@ -44,6 +47,7 @@ class LRFinder():
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
             avg_loss = AverageMeter()
+            avg_acc = AverageMeter()
 
             self.model.train()
 
@@ -71,28 +75,37 @@ class LRFinder():
                     loss, acc, bsz = self.loss_fn(batch)
 
                     avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
                     progress_bar.update(1)
 
             lr_scheduler.step()
 
             loss = avg_loss.avg.item()
+            acc = avg_acc.avg.item()
+
             if epoch == 0:
                 best_loss = loss
+                best_acc = acc
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                 if loss < best_loss:
                     best_loss = loss
+                if acc > best_acc:
+                    best_acc = acc
 
             lr = lr_scheduler.get_last_lr()[0]
 
             lrs.append(lr)
             losses.append(loss)
+            accs.append(acc)
 
             progress_bar.set_postfix({
                 "loss": loss,
-                "best": best_loss,
+                "loss/best": best_loss,
+                "acc": acc,
+                "acc/best": best_acc,
                 "lr": lr,
             })
 
@@ -103,20 +116,37 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        fig, ax = plt.subplots()
-        ax.plot(lrs, losses)
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
+        fig, ax_loss = plt.subplots()
+
+        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.set_xscale("log")
+        ax_loss.set_xlabel("Learning rate")
+
+        # ax_acc = ax_loss.twinx()
+        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
+
         if min_grad_idx is not None:
             print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
-            ax.scatter(
+            ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
                 s=75,
@@ -125,11 +155,7 @@ class LRFinder():
                 zorder=3,
                 label="steepest gradient",
             )
-        ax.legend()
-
-        ax.set_xscale("log")
-        ax.set_xlabel("Learning rate")
-        ax.set_ylabel("Loss")
+        ax_loss.legend()
 
 
 def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
diff --git a/training/util.py b/training/util.py
index a0c15cd..d0f7fcd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -5,11 +5,6 @@ import torch
 from PIL import Image
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
--
cgit v1.2.3-70-g09d2
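
Below is a hedged usage sketch, not part of the commit, showing how the updated run() signature might be called after this change. The LRFinder constructor is untouched by this diff, so the lr_finder instance and the argument values are assumptions for illustration only.

# Usage sketch (assumption, not from this commit): calling the updated run()
# with the new skip_start/skip_end arguments. The LRFinder constructor is not
# shown in this diff, so lr_finder is assumed to be an already-built instance
# wrapping the model, optimizer, data loaders and loss_fn that run() uses.
import math

import matplotlib.pyplot as plt

# lr_finder = LRFinder(...)  # constructor arguments are outside this diff

lr_finder.run(
    min_lr=1e-6,           # lower bound of the exponential LR sweep (illustrative)
    skip_start=10,         # drop the first, noisy points before plotting
    skip_end=5,            # drop the last points, where the loss tends to diverge
    num_epochs=100,
    num_train_batches=1,
    num_val_batches=math.inf,
    smooth_f=0.05,
    diverge_th=5,
)

plt.show()  # run() builds the loss-vs-LR figure; displaying it is left to the caller

Note that the accuracy curve on a twinned axis is left commented out in the commit, so only the loss curve is drawn.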