From 67aaba2159bcda4c0b8538b1580a40f01e8f0964 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 2 Jan 2023 17:34:11 +0100
Subject: Update

---
 train_dreambooth.py | 16 ++++++++++++----
 train_ti.py         | 15 ++++++++++++---
 training/lr.py      | 33 +++++++++++++++++++++++++++++++--
 3 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 05f6cb5..1e49474 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -565,8 +565,6 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
-
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -893,7 +891,16 @@ def main():
         accelerator.init_trackers("dreambooth", config=config)
 
     if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder = LRFinder(
+            accelerator,
+            text_encoder,
+            optimizer,
+            train_dataloader,
+            val_dataloader,
+            loop,
+            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
+            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+        )
         lr_finder.run(min_lr=1e-4)
 
         plt.savefig(basepath.joinpath("lr.png"))
@@ -965,11 +972,11 @@ def main():
         local_progress_bar.reset()
 
         unet.train()
-
         if epoch < args.train_text_encoder_epochs:
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
             text_encoder.requires_grad_(False)
+        tokenizer.set_use_vector_shuffle(args.vector_shuffle)
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
@@ -1023,6 +1030,7 @@ def main():
 
         unet.eval()
         text_encoder.eval()
+        tokenizer.set_use_vector_shuffle(False)
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
diff --git a/train_ti.py b/train_ti.py
index 97dde1e..2b3f017 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -535,8 +535,6 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
-
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -845,7 +843,16 @@ def main():
         accelerator.init_trackers("textual_inversion", config=config)
 
     if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder = LRFinder(
+            accelerator,
+            text_encoder,
+            optimizer,
+            train_dataloader,
+            val_dataloader,
+            loop,
+            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
+            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+        )
         lr_finder.run(min_lr=1e-4)
 
         plt.savefig(basepath.joinpath("lr.png"))
@@ -915,6 +922,7 @@ def main():
         local_progress_bar.reset()
 
         text_encoder.train()
+        tokenizer.set_use_vector_shuffle(args.vector_shuffle)
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
@@ -955,6 +963,7 @@ def main():
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
+        tokenizer.set_use_vector_shuffle(False)
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
diff --git a/training/lr.py b/training/lr.py
index 3abd2f2..fe166ed 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,6 @@
 import math
 import copy
+from typing import Callable
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,19 +11,45 @@ from tqdm.auto import tqdm
 from training.util import AverageMeter
 
 
+def noop():
+    pass
+
+
 class LRFinder():
-    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
+    def __init__(
+        self,
+        accelerator,
+        model,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        loss_fn,
+        on_train: Callable[[], None] = noop,
+        on_eval: Callable[[], None] = noop
+    ):
         self.accelerator = accelerator
         self.model = model
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
+        self.on_train = on_train
+        self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(
+        self,
+        min_lr,
+        skip_start: int = 10,
+        skip_end: int = 5,
+        num_epochs: int = 100,
+        num_train_batches: int = 1,
+        num_val_batches: int = math.inf,
+        smooth_f: float = 0.05,
+        diverge_th: int = 5
+    ):
        best_loss = None
         best_acc = None
 
@@ -50,6 +77,7 @@ class LRFinder():
             avg_acc = AverageMeter()
 
             self.model.train()
+            self.on_train()
 
             for step, batch in enumerate(self.train_dataloader):
                 if step >= num_train_batches:
@@ -67,6 +95,7 @@ class LRFinder():
                     progress_bar.update(1)
 
             self.model.eval()
+            self.on_eval()
 
             with torch.inference_mode():
                 for step, batch in enumerate(self.val_dataloader):
-- 
cgit v1.2.3-70-g09d2
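
Usage note (editorial, not part of the patch): the on_train/on_eval parameters added above are an instance of the default-to-noop callback pattern, which lets LRFinder toggle caller-owned state (here, the tokenizer's vector shuffle) at phase boundaries without knowing about the tokenizer itself. A minimal, self-contained Python sketch of that pattern; the Finder class and its hooks are toy stand-ins, not this repo's LRFinder:

    from typing import Callable


    def noop() -> None:
        pass


    class Finder:
        # Toy stand-in: the caller injects phase hooks instead of the
        # finder depending on tokenizer internals.
        def __init__(
            self,
            on_train: Callable[[], None] = noop,
            on_eval: Callable[[], None] = noop,
        ):
            self.on_train = on_train
            self.on_eval = on_eval

        def run_epoch(self) -> None:
            self.on_train()  # caller enables train-only behaviour (e.g. vector shuffle)
            # ... training steps would go here ...
            self.on_eval()   # caller restores deterministic eval behaviour
            # ... validation steps would go here ...


    finder = Finder(
        on_train=lambda: print("vector shuffle on"),
        on_eval=lambda: print("vector shuffle off"),
    )
    finder.run_epoch()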