author     Volpeon <git@volpeon.ink>   2023-01-12 13:50:22 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-12 13:50:22 +0100
commit     f963d4cba5c4c6575d77be80621a40b615603ca3
tree       d60ecb8c99534a12cef8070b0cf5a77eecc1c8d1 /train_dreambooth.py
parent     Fixed TI decay
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py | 56
1 file changed, 42 insertions, 14 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 73d9935..ebcf802 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,6 +87,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
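
The new --num_vectors flag accepts zero or more integers, one per placeholder token, and stays None when the flag is omitted. A minimal, standalone sketch of how such an argparse argument parses (plain argparse, not the project's parser):

import argparse

# Standalone illustration of the type=int / nargs='*' behaviour of the flag
# added above; this is not the project's parser.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num_vectors",
    type=int,
    nargs='*',
    help="Number of vectors per embedding."
)

print(parser.parse_args([]).num_vectors)                           # None (flag omitted)
print(parser.parse_args(["--num_vectors", "3"]).num_vectors)       # [3]
print(parser.parse_args(["--num_vectors", "3", "1"]).num_vectors)  # [3, 1]

Because nargs='*' always yields a list (or None) from the command line, the isinstance(args.num_vectors, int) branch added further down presumably covers values coming in through the project's config handling rather than the CLI.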
@@ -444,17 +450,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
     if isinstance(args.placeholder_token, str):
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+
+    if len(args.initializer_token) == 0:
+        raise ValueError("You must specify --initializer_token")
 
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
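
The reworked validation broadcasts a single initializer token or a single vector count across every placeholder token and then enforces matching lengths. A standalone sketch of those rules as a hypothetical helper (normalize_token_args is not part of the diff, and len() is assumed around the initializer list where default placeholder names are generated):

# Hypothetical helper restating the normalization rules from the hunk above;
# not the project's code.
def normalize_token_args(placeholder_token, initializer_token, num_vectors):
    if isinstance(placeholder_token, str):
        placeholder_token = [placeholder_token]

    if len(placeholder_token) == 0:
        # one auto-named placeholder per initializer token
        placeholder_token = [f"<*{i}>" for i in range(len(initializer_token))]

    if isinstance(initializer_token, str):
        # a single initializer string is reused for every placeholder
        initializer_token = [initializer_token] * len(placeholder_token)

    if len(initializer_token) == 0:
        raise ValueError("You must specify --initializer_token")

    if len(placeholder_token) != len(initializer_token):
        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")

    if num_vectors is None:
        num_vectors = 1

    if isinstance(num_vectors, int):
        # a single vector count is reused for every placeholder
        num_vectors = [num_vectors] * len(initializer_token)

    if len(placeholder_token) != len(num_vectors):
        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

    return placeholder_token, initializer_token, num_vectors

For example, normalize_token_args("<token>", "person", None) returns (["<token>"], ["person"], [1]).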
@@ -882,6 +900,18 @@ def main():
     finally:
         pass
 
+    def on_before_optimize():
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if not args.train_text_encoder:
+            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+
     loop = partial(
         loss_step,
         vae,
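
The two new callbacks lift gradient clipping and embedding re-normalization out of the inner training loop so the same code can also be handed to the LR finder (see the later hunks). A minimal sketch of the pattern in plain PyTorch, with toy modules standing in for unet and text_encoder; the accelerator.sync_gradients guard is dropped here, and the embeddings normalize() call, which lives in the repo's patched text encoder, is only stubbed:

import itertools
import torch

# Toy stand-ins; the real modules come from diffusers/transformers.
unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(
    itertools.chain(unet.parameters(), text_encoder.parameters()), lr=1e-4
)
max_grad_norm = 1.0
train_text_encoder = True

def on_before_optimize():
    # clip the UNet always, the text encoder only while it is being trained
    params_to_clip = [unet.parameters()]
    if train_text_encoder:
        params_to_clip.append(text_encoder.parameters())
    torch.nn.utils.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

@torch.no_grad()
def on_after_optimize(lr: float):
    # In the diff this re-normalizes the learned token embeddings with a
    # strength of min(1.0, 100 * lr); the actual normalize() implementation
    # is part of the repo and not reproduced here.
    pass

loss = (unet(torch.randn(2, 8)) + text_encoder(torch.randn(2, 8))).pow(2).mean()
loss.backward()
on_before_optimize()            # clip just before the optimizer step
optimizer.step()
on_after_optimize(optimizer.param_groups[0]["lr"])
optimizer.zero_grad()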
@@ -915,10 +945,12 @@ def main():
             loop,
             on_train=tokenizer.train,
             on_eval=tokenizer.eval,
+            on_before_optimize=on_before_optimize,
+            on_after_optimize=on_after_optimize,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"))
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
 
         quit()
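
The LR finder is now given the same two hooks and sweeps further, up to end_lr=1e3 over 100 epochs, before the curve is written out at dpi=300. The LRFinder class itself is part of the repo and not shown in this diff; as a rough, generic sketch of what an exponential LR range test of that shape does (assumed behaviour, not the repo's implementation):

import torch

# Generic LR range test sketch: ramp the LR geometrically from the optimizer's
# base LR up to end_lr and record the loss at each step for plotting.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)

num_steps = 100
end_lr = 1e3
base_lr = optimizer.param_groups[0]["lr"]
gamma = (end_lr / base_lr) ** (1.0 / num_steps)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

history = []
for step in range(num_steps):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    # the loss eventually blowing up at large LRs is the point of the test;
    # these pairs are what a plot like lr.png would be drawn from
    history.append((scheduler.get_last_lr()[0], loss.item()))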
@@ -999,13 +1031,7 @@ def main():
 
                 accelerator.backward(loss)
 
-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                        else unet.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                on_before_optimize()
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
@@ -1019,6 +1045,8 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
+                    on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
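
get_last_lr() returns one learning rate per parameter group, so indexing [0] hands the first group's current rate to on_after_optimize once an optimization step has actually happened. A small standalone illustration of that call with a stock PyTorch scheduler (the project's scheduler is configured elsewhere):

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for step in range(3):
    model(torch.randn(1, 4)).sum().backward()
    optimizer.step()
    lr_scheduler.step()
    current_lr = lr_scheduler.get_last_lr()[0]  # one entry per param group
    optimizer.zero_grad()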