From f963d4cba5c4c6575d77be80621a40b615603ca3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 12 Jan 2023 13:50:22 +0100
Subject: Update

---
 train_dreambooth.py | 56 ++++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 42 insertions(+), 14 deletions(-)

(limited to 'train_dreambooth.py')

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 73d9935..ebcf802 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -86,6 +86,12 @@ def parse_args():
         default=[],
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
     parser.add_argument(
         "--exclude_collections",
         type=str,
@@ -444,17 +450,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
 
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -882,6 +900,18 @@ def main():
         finally:
             pass
 
+    def on_before_optimize():
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if not args.train_text_encoder:
+            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+
     loop = partial(
         loss_step,
         vae,
@@ -915,10 +945,12 @@ def main():
             loop,
             on_train=tokenizer.train,
             on_eval=tokenizer.eval,
+            on_before_optimize=on_before_optimize,
+            on_after_optimize=on_after_optimize,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
 
         quit()
@@ -999,13 +1031,7 @@ def main():
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                        else unet.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                on_before_optimize()
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
@@ -1019,6 +1045,8 @@
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
+                    on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
--
cgit v1.2.3-54-g00ecf