| | |
|---|---|
| author | Volpeon <git@volpeon.ink>, 2023-01-12 13:50:22 +0100 |
| committer | Volpeon <git@volpeon.ink>, 2023-01-12 13:50:22 +0100 |
| commit | f963d4cba5c4c6575d77be80621a40b615603ca3 (patch) |
| tree | d60ecb8c99534a12cef8070b0cf5a77eecc1c8d1 /train_ti.py |
| parent | Fixed TI decay (diff) |
| download | textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.tar.gz, textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.tar.bz2, textual-inversion-diff-f963d4cba5c4c6575d77be80621a40b615603ca3.zip |
Update
Diffstat (limited to 'train_ti.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_ti.py | 30 |

1 file changed, 14 insertions, 16 deletions
```diff
diff --git a/train_ti.py b/train_ti.py
index 890c465..9ec5cfb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -452,27 +452,27 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
     if args.num_vectors is None:
         args.num_vectors = 1
 
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.initializer_token)
 
-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
-
     if len(args.placeholder_token) != len(args.num_vectors):
         raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
```
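The reordering matters: the placeholder list is now normalized first, so a single `--initializer_token` string can be broadcast across however many placeholders were given. A minimal sketch of that broadcasting with made-up token values (the `<art-style>` names below are illustrative, not from the repository):

```python
# Illustrative stand-ins; the argparse plumbing from train_ti.py is omitted.
placeholder_token = ["<art-style>", "<art-style-wide>"]
initializer_token = "painting"  # a single --initializer_token value

# Mirrors the new check: one initializer string is repeated per placeholder.
if isinstance(initializer_token, str):
    initializer_token = [initializer_token] * len(placeholder_token)

assert initializer_token == ["painting", "painting"]
assert len(placeholder_token) == len(initializer_token)
```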
```diff
@@ -867,7 +867,7 @@ def main():
         pass
 
     @torch.no_grad()
-    def on_clip(lr):
+    def on_after_optimize(lr: float):
         text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
 
     loop = partial(
@@ -904,7 +904,7 @@ def main():
         loop,
         on_train=on_train,
         on_eval=on_eval,
-        on_clip=on_clip,
+        on_after_optimize=on_after_optimize,
     )
     lr_finder.run(num_epochs=100, end_lr=1e3)
 
```
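Besides the rename, `on_after_optimize` keeps scaling the embedding re-normalization with the current learning rate, capped at 1.0. A quick standalone check of the `min(1.0, 100 * lr)` factor it forwards to `embeddings.normalize` (the learning-rate values are arbitrary examples):

```python
# The normalization strength grows with the learning rate and saturates at 1.0.
for lr in (1e-5, 1e-4, 1e-2, 1e-1):
    print(f"lr={lr:g} -> factor {min(1.0, 100 * lr):g}")
# lr=1e-05 -> factor 0.001
# lr=0.0001 -> factor 0.01
# lr=0.01 -> factor 1
# lr=0.1 -> factor 1
```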
```diff
@@ -985,12 +985,8 @@ def main():
 
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    on_clip(lr_scheduler.get_last_lr()[0])
-
                 optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 avg_loss.update(loss.detach_(), bsz)
@@ -998,6 +994,8 @@
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                 if args.use_ema:
                     ema_embeddings.step(
                         text_encoder.text_model.embeddings.temp_token_embedding.parameters())
```
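In the training loop the hook no longer runs before `optimizer.step()`; it is now called after the step and gated on `accelerator.sync_gradients`, which under gradient accumulation is only true on the iteration where the accumulated gradients are actually applied. A self-contained toy sketch of that placement; the model, data, scheduler, and hook body below are stand-ins rather than code from train_ti.py:

```python
import torch
from accelerate import Accelerator


def main():
    accelerator = Accelerator(gradient_accumulation_steps=2)
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
    batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(4)]

    @torch.no_grad()
    def on_after_optimize(lr: float):
        # Stand-in for text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)).
        print(f"post-optimize hook, lr={lr:g}")

    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

    for x, y in batches:
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        # True only on steps where the accumulated gradients were applied,
        # so the hook (and, in train_ti.py, the EMA update) runs once per
        # effective optimizer step.
        if accelerator.sync_gradients:
            on_after_optimize(lr_scheduler.get_last_lr()[0])


if __name__ == "__main__":
    main()
```

With `gradient_accumulation_steps=2` and four batches, the hook should fire twice, after the second and fourth batch.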
