From f963d4cba5c4c6575d77be80621a40b615603ca3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 12 Jan 2023 13:50:22 +0100
Subject: Update

---
 train_dreambooth.py | 56 ++++++++++++++++++++++++++++++++++++++++--------------
 train_ti.py         | 30 ++++++++++++++----------------
 training/lr.py      | 11 +++++++----
 3 files changed, 63 insertions(+), 34 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 73d9935..ebcf802 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -86,6 +86,12 @@ def parse_args():
         default=[],
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
     parser.add_argument(
         "--exclude_collections",
         type=str,
@@ -444,17 +450,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
 
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -882,6 +900,18 @@ def main():
         finally:
             pass
 
+    def on_before_optimize():
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if not args.train_text_encoder:
+            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+
     loop = partial(
         loss_step,
         vae,
@@ -915,10 +945,12 @@ def main():
             loop,
             on_train=tokenizer.train,
             on_eval=tokenizer.eval,
+            on_before_optimize=on_before_optimize,
+            on_after_optimize=on_after_optimize,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
 
         quit()
@@ -999,13 +1031,7 @@ def main():
 
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                        else unet.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                on_before_optimize()
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
@@ -1019,6 +1045,8 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
diff --git a/train_ti.py b/train_ti.py
index 890c465..9ec5cfb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -452,27 +452,27 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
     if args.num_vectors is None:
         args.num_vectors = 1
 
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.initializer_token)
 
-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
-
     if len(args.placeholder_token) != len(args.num_vectors):
         raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
@@ -867,7 +867,7 @@ def main():
             pass
 
     @torch.no_grad()
-    def on_clip(lr):
+    def on_after_optimize(lr: float):
         text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
 
     loop = partial(
@@ -904,7 +904,7 @@ def main():
             loop,
             on_train=on_train,
             on_eval=on_eval,
-            on_clip=on_clip,
+            on_after_optimize=on_after_optimize,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
@@ -985,12 +985,8 @@ def main():
 
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    on_clip(lr_scheduler.get_last_lr()[0])
-
                 optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 avg_loss.update(loss.detach_(), bsz)
@@ -998,6 +994,8 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                 if args.use_ema:
                     ema_embeddings.step(
                         text_encoder.text_model.embeddings.temp_token_embedding.parameters())
diff --git a/training/lr.py b/training/lr.py
index 01f7f5e..84e30a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,8 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_clip: Callable[[float], None] = noop,
+        on_before_optimize: Callable[[], None] = noop,
+        on_after_optimize: Callable[[float], None] = noop,
         on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
@@ -36,7 +37,8 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
         self.on_train = on_train
-        self.on_clip = on_clip
+        self.on_before_optimize = on_before_optimize
+        self.on_after_optimize = on_after_optimize
         self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
@@ -94,14 +96,15 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
-                if self.accelerator.sync_gradients:
-                    self.on_clip(lr_scheduler.get_last_lr()[0])
+                self.on_before_optimize()
 
                 self.optimizer.step()
                 lr_scheduler.step()
                 self.optimizer.zero_grad(set_to_none=True)
 
                 if self.accelerator.sync_gradients:
+                    self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     progress_bar.update(1)
 
         self.model.eval()
-- 
cgit v1.2.3-70-g09d2
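
Note on the callback contract this patch introduces: on_before_optimize() is called right
before optimizer.step() (train_dreambooth.py uses it for gradient clipping), and
on_after_optimize(lr) is called with the current learning rate once the accelerator has
actually performed an optimizer step (train_ti.py uses it to normalize the embeddings).
The sketch below only illustrates that calling convention; it is not part of the patch,
and the helper names (training_step, backward, optimizer_step) are made up for the example.
Only noop comes from training/lr.py.

    from typing import Callable


    def noop(*args, **kwargs):
        # Default hook that does nothing, mirroring training/lr.py.
        pass


    def training_step(
        backward: Callable[[], None],
        optimizer_step: Callable[[], None],
        current_lr: float,
        on_before_optimize: Callable[[], None] = noop,
        on_after_optimize: Callable[[float], None] = noop,
    ) -> None:
        backward()                       # accelerator.backward(loss)
        on_before_optimize()             # e.g. accelerator.clip_grad_norm_(...)
        optimizer_step()                 # optimizer.step(); lr_scheduler.step(); zero_grad()
        on_after_optimize(current_lr)    # e.g. embeddings.normalize(min(1.0, 100 * lr))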