From 30098b1d611853c0d3a4687d84582e1c1cf1b938 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Dec 2022 11:48:33 +0100 Subject: Added validation phase to learn rate finder --- train_ti.py | 12 ++++++++---- training/lr.py | 34 +++++++++++++++++++++++----------- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/train_ti.py b/train_ti.py index ab00b60..32f44f4 100644 --- a/train_ti.py +++ b/train_ti.py @@ -14,6 +14,7 @@ from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup +import matplotlib.pyplot as plt from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify @@ -451,6 +452,7 @@ def main(): global_step_offset = args.global_step now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") basepath = Path(args.output_dir).joinpath(slugify(args.project), now) + basepath.mkdir(parents=True, exist_ok=True) if args.find_lr: accelerator = Accelerator( @@ -458,8 +460,6 @@ def main(): mixed_precision=args.mixed_precision ) else: - basepath.mkdir(parents=True, exist_ok=True) - accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, logging_dir=f"{basepath}", @@ -782,8 +782,12 @@ def main(): return loss, acc, bsz if args.find_lr: - lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop) - lr_finder.run() + lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) + lr_finder.run(num_train_steps=2) + + plt.savefig(basepath.joinpath("lr.png")) + plt.close() + quit() # We need to initialize the trackers we use, and also store our configuration. diff --git a/training/lr.py b/training/lr.py index dd37baa..5343f24 100644 --- a/training/lr.py +++ b/training/lr.py @@ -1,20 +1,22 @@ +import matplotlib.pyplot as plt import numpy as np +import torch from torch.optim.lr_scheduler import LambdaLR from tqdm.auto import tqdm -import matplotlib.pyplot as plt from training.util import AverageMeter class LRFinder(): - def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn): + def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn): self.accelerator = accelerator self.model = model self.optimizer = optimizer self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader self.loss_fn = loss_fn - def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5): + def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5): best_loss = None lrs = [] losses = [] @@ -22,7 +24,7 @@ class LRFinder(): lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs) progress_bar = tqdm( - range(num_epochs * num_steps), + range(num_epochs * (num_train_steps + num_val_steps)), disable=not self.accelerator.is_local_main_process, dynamic_ncols=True ) @@ -33,6 +35,8 @@ class LRFinder(): avg_loss = AverageMeter() + self.model.train() + for step, batch in enumerate(self.train_dataloader): with self.accelerator.accumulate(self.model): loss, acc, bsz = self.loss_fn(batch) @@ -42,13 +46,24 @@ class LRFinder(): self.optimizer.step() self.optimizer.zero_grad(set_to_none=True) - avg_loss.update(loss.detach_(), bsz) + if self.accelerator.sync_gradients: + progress_bar.update(1) - if step >= num_steps: + if step >= num_train_steps: break - if self.accelerator.sync_gradients: - progress_bar.update(1) + self.model.eval() + + with torch.inference_mode(): + for step, batch in enumerate(self.val_dataloader): + loss, acc, bsz = self.loss_fn(batch) + avg_loss.update(loss.detach_(), bsz) + + if self.accelerator.sync_gradients: + progress_bar.update(1) + + if step >= num_val_steps: + break lr_scheduler.step() @@ -104,9 +119,6 @@ class LRFinder(): ax.set_xlabel("Learning rate") ax.set_ylabel("Loss") - if fig is not None: - plt.show() - def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1): def lr_lambda(current_epoch: int): -- cgit v1.2.3-70-g09d2