From f963d4cba5c4c6575d77be80621a40b615603ca3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 12 Jan 2023 13:50:22 +0100
Subject: Update

---
 train_ti.py | 30 ++++++++++++++----------------
 1 file changed, 14 insertions(+), 16 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 890c465..9ec5cfb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -452,27 +452,27 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]

     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]

+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
     if args.num_vectors is None:
         args.num_vectors = 1

     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.initializer_token)

-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
-
     if len(args.placeholder_token) != len(args.num_vectors):
         raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

@@ -867,7 +867,7 @@ def main():
             pass

         @torch.no_grad()
-        def on_clip(lr):
+        def on_after_optimize(lr: float):
             text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))

         loop = partial(
@@ -904,7 +904,7 @@ def main():
             loop,
             on_train=on_train,
             on_eval=on_eval,
-            on_clip=on_clip,
+            on_after_optimize=on_after_optimize,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)

@@ -985,12 +985,8 @@ def main():

                 accelerator.backward(loss)

-                if accelerator.sync_gradients:
-                    on_clip(lr_scheduler.get_last_lr()[0])
-
                 optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

                 avg_loss.update(loss.detach_(), bsz)
@@ -998,6 +994,8 @@ def main():

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                 if args.use_ema:
                     ema_embeddings.step(
                         text_encoder.text_model.embeddings.temp_token_embedding.parameters())
--
cgit v1.2.3-54-g00ecf
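
For reference, a minimal standalone sketch of the argument-normalization behavior this patch introduces in parse_args(). The helper name `normalize_token_args` and the `Namespace` inputs below are illustrative assumptions, not part of train_ti.py; the point is that a single --initializer_token string is now broadcast to match the number of placeholder tokens, and the placeholder/initializer length check runs before --num_vectors is expanded.

```python
# Illustrative sketch only; normalize_token_args is a hypothetical helper
# mirroring the post-patch logic, not a function from train_ti.py.
from argparse import Namespace


def normalize_token_args(args: Namespace) -> Namespace:
    # Wrap a single placeholder token in a list.
    if isinstance(args.placeholder_token, str):
        args.placeholder_token = [args.placeholder_token]

    # Broadcast a single initializer token to every placeholder token.
    if isinstance(args.initializer_token, str):
        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)

    if len(args.initializer_token) == 0:
        raise ValueError("You must specify --initializer_token")

    if len(args.placeholder_token) != len(args.initializer_token):
        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")

    # Expand a single vector count only after the token lists are aligned.
    if args.num_vectors is None:
        args.num_vectors = 1
    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.initializer_token)

    return args


# Example: one initializer token shared by two placeholder tokens.
args = normalize_token_args(Namespace(
    placeholder_token=["<tok1>", "<tok2>"],
    initializer_token="person",
    num_vectors=None,
))
print(args.initializer_token)  # ['person', 'person']
print(args.num_vectors)        # [1, 1]
```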