| field | value | date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-01-12 13:50:22 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2023-01-12 13:50:22 +0100 |
| commit | f963d4cba5c4c6575d77be80621a40b615603ca3 | |
| tree | d60ecb8c99534a12cef8070b0cf5a77eecc1c8d1 /train_dreambooth.py | |
| parent | Fixed TI decay | |
Update
Diffstat (limited to 'train_dreambooth.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_dreambooth.py | 56 |

1 file changed, 42 insertions(+), 14 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 73d9935..ebcf802 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,6 +87,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
@@ -444,17 +450,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
 
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -882,6 +900,18 @@ def main():
         finally:
             pass
 
+    def on_before_optimize():
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if not args.train_text_encoder:
+            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+
     loop = partial(
         loss_step,
         vae,
@@ -915,10 +945,12 @@ def main():
             loop,
             on_train=tokenizer.train,
             on_eval=tokenizer.eval,
+            on_before_optimize=on_before_optimize,
+            on_after_optimize=on_after_optimize,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
 
         quit()
@@ -999,13 +1031,7 @@ def main():
 
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                        else unet.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                on_before_optimize()
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
@@ -1019,6 +1045,8 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
+                    on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
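For readers skimming the parse_args changes: their net effect is that `--placeholder_token`, `--initializer_token`, and the new `--num_vectors` are normalized into lists of equal length, with a single scalar value broadcast across all tokens. Below is a minimal standalone sketch of that broadcasting behaviour; the helper name `normalize_token_args` and the `Namespace` literal are illustrative only and not part of the commit.

```python
from argparse import Namespace


def normalize_token_args(args: Namespace) -> Namespace:
    # Wrap a single placeholder token into a list.
    if isinstance(args.placeholder_token, str):
        args.placeholder_token = [args.placeholder_token]

    # Broadcast a single initializer token across all placeholder tokens.
    if isinstance(args.initializer_token, str):
        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)

    # Default to one vector per embedding, then broadcast a scalar to all tokens.
    if args.num_vectors is None:
        args.num_vectors = 1
    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.initializer_token)

    if len(args.placeholder_token) != len(args.initializer_token):
        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
    if len(args.placeholder_token) != len(args.num_vectors):
        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
    return args


args = normalize_token_args(Namespace(
    placeholder_token=["<tok1>", "<tok2>"],
    initializer_token="person",
    num_vectors=4,
))
print(args.initializer_token)  # ['person', 'person']
print(args.num_vectors)        # [4, 4]
```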
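The main() changes factor the gradient clipping and the post-step embedding renormalization out of the training loop into `on_before_optimize` / `on_after_optimize` callbacks, which are also passed to the LR finder so both code paths share the same behaviour. A rough sketch of that callback pattern using plain PyTorch stand-ins follows; the toy modules, the absence of `accelerate`, and the no-op post-step hook are assumptions for illustration, not the project's API.

```python
import itertools

import torch

# Toy stand-ins for the real unet / text_encoder modules (illustrative only).
unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)
max_grad_norm = 1.0

optimizer = torch.optim.SGD(
    itertools.chain(unet.parameters(), text_encoder.parameters()), lr=1e-2
)


def on_before_optimize():
    # Clip the gradients of every trainable module right before the optimizer step.
    params = itertools.chain(unet.parameters(), text_encoder.parameters())
    torch.nn.utils.clip_grad_norm_(params, max_grad_norm)


@torch.no_grad()
def on_after_optimize(lr: float):
    # Post-step bookkeeping hook; in the commit this renormalizes the learned
    # embeddings with a strength tied to the current learning rate.
    pass


x = torch.randn(2, 4)
loss = (text_encoder(unet(x)) ** 2).mean()
loss.backward()
on_before_optimize()
optimizer.step()
optimizer.zero_grad()
on_after_optimize(lr=1e-2)
```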
