summary refs log tree commit diff stats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r-- train_ti.py 22
1 files changed, 16 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py
index df8d443..35be74c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -169,6 +169,11 @@ def parse_args():
169 help="Tag dropout probability.", 169 help="Tag dropout probability.",
170 ) 170 )
171 parser.add_argument( 171 parser.add_argument(
172 "--tag_shuffle",
173 action="store_true",
174 help="Shuffle tags.",
175 )
176 parser.add_argument(
172 "--vector_dropout", 177 "--vector_dropout",
173 type=int, 178 type=int,
174 default=0, 179 default=0,
@@ -395,7 +400,7 @@ def parse_args():
395 parser.add_argument( 400 parser.add_argument(
396 "--sample_steps", 401 "--sample_steps",
397 type=int, 402 type=int,
398 default=15, 403 default=20,
399 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 404 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
400 ) 405 )
401 parser.add_argument( 406 parser.add_argument(
@@ -745,6 +750,7 @@ def main():
745 bucket_step_size=args.bucket_step_size, 750 bucket_step_size=args.bucket_step_size,
746 bucket_max_pixels=args.bucket_max_pixels, 751 bucket_max_pixels=args.bucket_max_pixels,
747 dropout=args.tag_dropout, 752 dropout=args.tag_dropout,
753 shuffle=args.tag_shuffle,
748 template_key=args.train_data_template, 754 template_key=args.train_data_template,
749 valid_set_size=args.valid_set_size, 755 valid_set_size=args.valid_set_size,
750 valid_set_repeat=args.valid_set_repeat, 756 valid_set_repeat=args.valid_set_repeat,
@@ -860,6 +866,12 @@ def main():
860 finally: 866 finally:
861 pass 867 pass
862 868
869 def on_clip():
870 accelerator.clip_grad_norm_(
871 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
872 args.max_grad_norm
873 )
874
863 loop = partial( 875 loop = partial(
864 run_model, 876 run_model,
865 vae, 877 vae,
@@ -894,8 +906,9 @@ def main():
894 loop, 906 loop,
895 on_train=on_train, 907 on_train=on_train,
896 on_eval=on_eval, 908 on_eval=on_eval,
909 on_clip=on_clip,
897 ) 910 )
898 lr_finder.run(num_epochs=200, end_lr=1e3) 911 lr_finder.run(num_epochs=100, end_lr=1e3)
899 912
900 plt.savefig(basepath.joinpath("lr.png"), dpi=300) 913 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
901 plt.close() 914 plt.close()
@@ -975,10 +988,7 @@ def main():
975 accelerator.backward(loss) 988 accelerator.backward(loss)
976 989
977 if accelerator.sync_gradients: 990 if accelerator.sync_gradients:
978 accelerator.clip_grad_norm_( 991 on_clip()
979 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
980 args.max_grad_norm
981 )
982 992
983 optimizer.step() 993 optimizer.step()
984 if not accelerator.optimizer_step_was_skipped: 994 if not accelerator.optimizer_step_was_skipped: