diff options
author | Volpeon <git@volpeon.ink> | 2023-01-10 09:22:02 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-10 09:22:02 +0100 |
commit | 33e7d2ed37e32657ca94d92815043026c4cea7c0 (patch) | |
tree | 0af4d6ad0ba92a168e3ec17675147c76afe1baf0 /train_ti.py | |
parent | Enable buckets for validation, fixed vaildation repeat arg (diff) | |
download | textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.tar.gz textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.tar.bz2 textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.zip |
Added arg to disable tag shuffling
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index df8d443..35be74c 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -169,6 +169,11 @@ def parse_args(): | |||
169 | help="Tag dropout probability.", | 169 | help="Tag dropout probability.", |
170 | ) | 170 | ) |
171 | parser.add_argument( | 171 | parser.add_argument( |
172 | "--tag_shuffle", | ||
173 | type="store_true", | ||
174 | help="Shuffle tags.", | ||
175 | ) | ||
176 | parser.add_argument( | ||
172 | "--vector_dropout", | 177 | "--vector_dropout", |
173 | type=int, | 178 | type=int, |
174 | default=0, | 179 | default=0, |
@@ -395,7 +400,7 @@ def parse_args(): | |||
395 | parser.add_argument( | 400 | parser.add_argument( |
396 | "--sample_steps", | 401 | "--sample_steps", |
397 | type=int, | 402 | type=int, |
398 | default=15, | 403 | default=20, |
399 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 404 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
400 | ) | 405 | ) |
401 | parser.add_argument( | 406 | parser.add_argument( |
@@ -745,6 +750,7 @@ def main(): | |||
745 | bucket_step_size=args.bucket_step_size, | 750 | bucket_step_size=args.bucket_step_size, |
746 | bucket_max_pixels=args.bucket_max_pixels, | 751 | bucket_max_pixels=args.bucket_max_pixels, |
747 | dropout=args.tag_dropout, | 752 | dropout=args.tag_dropout, |
753 | shuffle=args.tag_shuffle, | ||
748 | template_key=args.train_data_template, | 754 | template_key=args.train_data_template, |
749 | valid_set_size=args.valid_set_size, | 755 | valid_set_size=args.valid_set_size, |
750 | valid_set_repeat=args.valid_set_repeat, | 756 | valid_set_repeat=args.valid_set_repeat, |
@@ -860,6 +866,12 @@ def main(): | |||
860 | finally: | 866 | finally: |
861 | pass | 867 | pass |
862 | 868 | ||
869 | def on_clip(): | ||
870 | accelerator.clip_grad_norm_( | ||
871 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
872 | args.max_grad_norm | ||
873 | ) | ||
874 | |||
863 | loop = partial( | 875 | loop = partial( |
864 | run_model, | 876 | run_model, |
865 | vae, | 877 | vae, |
@@ -894,8 +906,9 @@ def main(): | |||
894 | loop, | 906 | loop, |
895 | on_train=on_train, | 907 | on_train=on_train, |
896 | on_eval=on_eval, | 908 | on_eval=on_eval, |
909 | on_clip=on_clip, | ||
897 | ) | 910 | ) |
898 | lr_finder.run(num_epochs=200, end_lr=1e3) | 911 | lr_finder.run(num_epochs=100, end_lr=1e3) |
899 | 912 | ||
900 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 913 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
901 | plt.close() | 914 | plt.close() |
@@ -975,10 +988,7 @@ def main(): | |||
975 | accelerator.backward(loss) | 988 | accelerator.backward(loss) |
976 | 989 | ||
977 | if accelerator.sync_gradients: | 990 | if accelerator.sync_gradients: |
978 | accelerator.clip_grad_norm_( | 991 | on_clip() |
979 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
980 | args.max_grad_norm | ||
981 | ) | ||
982 | 992 | ||
983 | optimizer.step() | 993 | optimizer.step() |
984 | if not accelerator.optimizer_step_was_skipped: | 994 | if not accelerator.optimizer_step_was_skipped: |