From 30098b1d611853c0d3a4687d84582e1c1cf1b938 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Dec 2022 11:48:33 +0100
Subject: Added validation phase to learn rate finder

---
 train_ti.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index ab00b60..32f44f4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -451,6 +452,7 @@ def main():
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
+    basepath.mkdir(parents=True, exist_ok=True)
 
     if args.find_lr:
         accelerator = Accelerator(
@@ -458,8 +460,6 @@ def main():
             mixed_precision=args.mixed_precision
         )
     else:
-        basepath.mkdir(parents=True, exist_ok=True)
-
         accelerator = Accelerator(
             log_with=LoggerType.TENSORBOARD,
             logging_dir=f"{basepath}",
@@ -782,8 +782,12 @@ def main():
         return loss, acc, bsz
 
     if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop)
-        lr_finder.run()
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(num_train_steps=2)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
         quit()
 
     # We need to initialize the trackers we use, and also store our configuration.
--
cgit v1.2.3-54-g00ecf
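
Note: the LRFinder class itself is not part of this patch. The following is a minimal, hypothetical sketch of what an LR finder with the validation phase added here might look like. It matches only what the call site above implies: the constructor arguments (accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop), the num_train_steps argument to run(), a loss callback of the form loop(batch) -> (loss, acc, bsz), and a matplotlib figure left open for the caller's plt.savefig(). All other names, defaults, and internals are assumptions, not the project's actual training/lr code.

# Hypothetical sketch (not the repository's actual implementation) of an
# LR finder with a validation phase, written against the call site above.
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import LambdaLR


class LRFinder:
    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
        self.accelerator = accelerator
        self.model = model
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.loss_fn = loss_fn  # assumed: loop(batch) -> (loss, acc, bsz)

    def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, min_lr=1e-6, max_lr=10.0):
        # Sweep the learning rate exponentially from min_lr to max_lr over num_epochs.
        gamma = (max_lr / min_lr) ** (1 / num_epochs)
        for group in self.optimizer.param_groups:
            group["lr"] = min_lr
        scheduler = LambdaLR(self.optimizer, lambda epoch: gamma ** epoch)

        lrs, losses = [], []
        for _ in range(num_epochs):
            # Training phase: a few optimizer steps at the current learning rate.
            self.model.train()
            for step, batch in enumerate(self.train_dataloader):
                if step >= num_train_steps:
                    break
                loss, _, _ = self.loss_fn(batch)
                self.accelerator.backward(loss)
                self.optimizer.step()
                self.optimizer.zero_grad()

            # Validation phase: measure the loss without updating any weights.
            self.model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for step, batch in enumerate(self.val_dataloader):
                    if step >= num_val_steps:
                        break
                    loss, _, _ = self.loss_fn(batch)
                    val_loss += loss.item()

            lrs.append(self.optimizer.param_groups[0]["lr"])
            losses.append(val_loss / num_val_steps)
            scheduler.step()

        # Plot onto the current figure; the caller saves it with plt.savefig().
        plt.plot(lrs, losses)
        plt.xscale("log")
        plt.xlabel("learning rate")
        plt.ylabel("validation loss")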