diff options
author | Volpeon <git@volpeon.ink> | 2023-01-12 13:50:22 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-12 13:50:22 +0100 |
commit | f963d4cba5c4c6575d77be80621a40b615603ca3 (patch) | |
tree | d60ecb8c99534a12cef8070b0cf5a77eecc1c8d1 /train_ti.py | |
parent | Fixed TI decay (diff) | |
download | textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.tar.gz textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.tar.bz2 textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 30 |
1 file changed, 14 insertions, 16 deletions
diff --git a/train_ti.py b/train_ti.py index 890c465..9ec5cfb 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -452,27 +452,27 @@ def parse_args(): | |||
452 | if args.project is None: | 452 | if args.project is None: |
453 | raise ValueError("You must specify --project") | 453 | raise ValueError("You must specify --project") |
454 | 454 | ||
455 | if isinstance(args.initializer_token, str): | ||
456 | args.initializer_token = [args.initializer_token] | ||
457 | |||
458 | if len(args.initializer_token) == 0: | ||
459 | raise ValueError("You must specify --initializer_token") | ||
460 | |||
461 | if isinstance(args.placeholder_token, str): | 455 | if isinstance(args.placeholder_token, str): |
462 | args.placeholder_token = [args.placeholder_token] | 456 | args.placeholder_token = [args.placeholder_token] |
463 | 457 | ||
464 | if len(args.placeholder_token) == 0: | 458 | if len(args.placeholder_token) == 0: |
465 | args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] | 459 | args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] |
466 | 460 | ||
461 | if isinstance(args.initializer_token, str): | ||
462 | args.initializer_token = [args.initializer_token] * len(args.placeholder_token) | ||
463 | |||
464 | if len(args.initializer_token) == 0: | ||
465 | raise ValueError("You must specify --initializer_token") | ||
466 | |||
467 | if len(args.placeholder_token) != len(args.initializer_token): | ||
468 | raise ValueError("--placeholder_token and --initializer_token must have the same number of items") | ||
469 | |||
467 | if args.num_vectors is None: | 470 | if args.num_vectors is None: |
468 | args.num_vectors = 1 | 471 | args.num_vectors = 1 |
469 | 472 | ||
470 | if isinstance(args.num_vectors, int): | 473 | if isinstance(args.num_vectors, int): |
471 | args.num_vectors = [args.num_vectors] * len(args.initializer_token) | 474 | args.num_vectors = [args.num_vectors] * len(args.initializer_token) |
472 | 475 | ||
473 | if len(args.placeholder_token) != len(args.initializer_token): | ||
474 | raise ValueError("--placeholder_token and --initializer_token must have the same number of items") | ||
475 | |||
476 | if len(args.placeholder_token) != len(args.num_vectors): | 476 | if len(args.placeholder_token) != len(args.num_vectors): |
477 | raise ValueError("--placeholder_token and --num_vectors must have the same number of items") | 477 | raise ValueError("--placeholder_token and --num_vectors must have the same number of items") |
478 | 478 | ||
@@ -867,7 +867,7 @@ def main(): | |||
867 | pass | 867 | pass |
868 | 868 | ||
869 | @torch.no_grad() | 869 | @torch.no_grad() |
870 | def on_clip(lr): | 870 | def on_after_optimize(lr: float): |
871 | text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) | 871 | text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) |
872 | 872 | ||
873 | loop = partial( | 873 | loop = partial( |
@@ -904,7 +904,7 @@ def main(): | |||
904 | loop, | 904 | loop, |
905 | on_train=on_train, | 905 | on_train=on_train, |
906 | on_eval=on_eval, | 906 | on_eval=on_eval, |
907 | on_clip=on_clip, | 907 | on_after_optimize=on_after_optimize, |
908 | ) | 908 | ) |
909 | lr_finder.run(num_epochs=100, end_lr=1e3) | 909 | lr_finder.run(num_epochs=100, end_lr=1e3) |
910 | 910 | ||
@@ -985,12 +985,8 @@ def main(): | |||
985 | 985 | ||
986 | accelerator.backward(loss) | 986 | accelerator.backward(loss) |
987 | 987 | ||
988 | if accelerator.sync_gradients: | ||
989 | on_clip(lr_scheduler.get_last_lr()[0]) | ||
990 | |||
991 | optimizer.step() | 988 | optimizer.step() |
992 | if not accelerator.optimizer_step_was_skipped: | 989 | lr_scheduler.step() |
993 | lr_scheduler.step() | ||
994 | optimizer.zero_grad(set_to_none=True) | 990 | optimizer.zero_grad(set_to_none=True) |
995 | 991 | ||
996 | avg_loss.update(loss.detach_(), bsz) | 992 | avg_loss.update(loss.detach_(), bsz) |
@@ -998,6 +994,8 @@ def main(): | |||
998 | 994 | ||
999 | # Checks if the accelerator has performed an optimization step behind the scenes | 995 | # Checks if the accelerator has performed an optimization step behind the scenes |
1000 | if accelerator.sync_gradients: | 996 | if accelerator.sync_gradients: |
997 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
998 | |||
1001 | if args.use_ema: | 999 | if args.use_ema: |
1002 | ema_embeddings.step( | 1000 | ema_embeddings.step( |
1003 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 1001 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |