Diffstat (limited to 'train_ti.py')

 train_ti.py (-rw-r--r--) | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index df8d443..35be74c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -169,6 +169,11 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--tag_shuffle",
+        action="store_true",
+        help="Shuffle tags.",
+    )
+    parser.add_argument(
         "--vector_dropout",
         type=int,
         default=0,
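The new --tag_shuffle option is a boolean switch (argparse action="store_true"): it defaults to False and flips to True when the flag is present, with no value expected after it. A minimal standalone sketch of that behavior (hypothetical throwaway script, not part of train_ti.py):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--tag_shuffle",
    action="store_true",  # boolean flag: no argument value expected
    help="Shuffle tags.",
)

print(parser.parse_args([]).tag_shuffle)                 # False
print(parser.parse_args(["--tag_shuffle"]).tag_shuffle)  # True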
@@ -395,7 +400,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=15,
+        default=20,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -745,6 +750,7 @@ def main():
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
         dropout=args.tag_dropout,
+        shuffle=args.tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
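The flag is forwarded to the data module as shuffle=args.tag_shuffle, next to the existing dropout=args.tag_dropout. The dataset code is outside this diff; presumably enabling shuffle permutes the comma-separated tags of each caption, roughly along these lines (an illustrative guess; shuffle_tags is hypothetical and not taken from the repository):

import random

def shuffle_tags(prompt: str) -> str:
    # Hypothetical sketch of the intent: randomize tag order so the learned
    # embedding does not depend on a fixed position of each tag in the caption.
    tags = [t.strip() for t in prompt.split(",") if t.strip()]
    random.shuffle(tags)
    return ", ".join(tags)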
@@ -860,6 +866,12 @@ def main():
     finally:
         pass

+    def on_clip():
+        accelerator.clip_grad_norm_(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            args.max_grad_norm
+        )
+
     loop = partial(
         run_model,
         vae,
@@ -894,8 +906,9 @@ def main():
             loop,
             on_train=on_train,
             on_eval=on_eval,
+            on_clip=on_clip,
         )
-        lr_finder.run(num_epochs=200, end_lr=1e3)
+        lr_finder.run(num_epochs=100, end_lr=1e3)

         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
@@ -975,10 +988,7 @@ def main():
                 accelerator.backward(loss)

                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(
-                        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-                        args.max_grad_norm
-                    )
+                    on_clip()

                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
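The remaining hunks are a single refactor: the accelerator.clip_grad_norm_ call on the temp_token_embedding parameters, previously inlined in the training step, is wrapped in on_clip() so that the LR finder (via on_clip=on_clip) and the training loop (whenever accelerator.sync_gradients is set) share the same clipping logic. The pattern in isolation, as a hedged sketch using a hypothetical make_on_clip factory (the diff itself simply defines on_clip as a closure inside main()):

from accelerate import Accelerator

def make_on_clip(accelerator: Accelerator, parameters, max_grad_norm: float):
    # Build a zero-argument callback that clips only the given parameters,
    # so the same clipping rule can be reused wherever gradients are synced.
    def on_clip():
        accelerator.clip_grad_norm_(parameters, max_grad_norm)
    return on_clip

Clipping only the temp token embedding parameters, rather than the whole text encoder, is consistent with textual inversion updating only the newly added token vectors.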
