From 46d631759f59bc6b65458202641e5f5a9bc30b7b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 2 Jan 2023 20:13:59 +0100
Subject: Fixed LR finder

---
 train_dreambooth.py |  4 ++--
 train_ti.py         |  6 +++---
 training/lr.py      | 38 +++++++++++++++++++++-----------------
 3 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1e49474..218018b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -634,7 +634,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e-4
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -901,7 +901,7 @@ def main():
            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
        )
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(end_lr=1e2)
 
        plt.savefig(basepath.joinpath("lr.png"))
        plt.close()
diff --git a/train_ti.py b/train_ti.py
index 2b3f017..102c0fa 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -584,7 +584,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e-4
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -853,9 +853,9 @@ def main():
            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
        )
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(end_lr=1e2)
 
-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
        plt.close()
 
        quit()
diff --git a/training/lr.py b/training/lr.py
index fe166ed..acc01a2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,7 @@
 import math
 import copy
 from typing import Callable
+from functools import partial
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -41,7 +42,7 @@ class LRFinder():
 
     def run(
         self,
-        min_lr,
+        end_lr,
         skip_start: int = 10,
         skip_end: int = 5,
         num_epochs: int = 100,
@@ -57,7 +58,7 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
 
         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -152,29 +153,30 @@ class LRFinder():
                    print("Stopping early, the loss has diverged")
                    break
 
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         fig, ax_loss = plt.subplots()
+        ax_acc = ax_loss.twinx()
 
         ax_loss.plot(lrs, losses, color='red')
         ax_loss.set_xscale("log")
-        ax_loss.set_xlabel("Learning rate")
+        ax_loss.set_xlabel(f"Learning rate")
         ax_loss.set_ylabel("Loss")
 
-        ax_acc = ax_loss.twinx()
         ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_xscale("log")
         ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
 
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
@@ -196,8 +198,10 @@ class LRFinder():
         ax_loss.legend()
 
 
-def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
-    def lr_lambda(current_epoch: int):
-        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)
+def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_epoch: int):
+        return (end_lr / base_lr) ** (current_epoch / num_epochs)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
 
-    return LambdaLR(optimizer, lr_lambda, last_epoch)
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
-- 
cgit v1.2.3-70-g09d2
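
Note (not part of the patch above): a minimal, runnable sketch of the rewritten get_exponential_schedule, illustrating why the finder now starts at args.learning_rate = 1e-4 and sweeps up to end_lr = 1e2. LambdaLR multiplies each param group's base LR by the lambda's return value, so (end_lr / base_lr) ** (t / T) interpolates from base_lr at t = 0 to end_lr at t = T on a log scale. The __main__ driver and the SGD optimizer here are illustrative assumptions, not code from the repository.

from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
    # Per param group: lr = base_lr * (end_lr / base_lr) ** (epoch / num_epochs),
    # i.e. base_lr at epoch 0 and end_lr at epoch num_epochs.
    def lr_lambda(base_lr: float, current_epoch: int):
        return (end_lr / base_lr) ** (current_epoch / num_epochs)

    # Bind each group's own base LR so groups with different LRs all reach end_lr.
    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]

    return LambdaLR(optimizer, lr_lambdas, last_epoch)


if __name__ == "__main__":
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1e-4)  # matches args.learning_rate = 1e-4
    scheduler = get_exponential_schedule(optimizer, end_lr=1e2, num_epochs=100)

    for _ in range(100):
        optimizer.step()
        scheduler.step()

    # After num_epochs scheduler steps the LR has swept from 1e-4 up to ~1e2.
    print(optimizer.param_groups[0]["lr"])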