author     Volpeon <git@volpeon.ink>  2023-04-16 18:20:38 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-16 18:20:38 +0200
commit     bd951892e300a0e21cb0e10fe261cb647ca160cd (patch)
tree       40f73452f6886be687cb6552588b114dc034fb00
parent     Fix (diff)
Added option to use constant LR on cycles > 1
-rw-r--r--  train_lora.py  13
-rw-r--r--  train_ti.py     9
2 files changed, 20 insertions, 2 deletions
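
The new --cycle_constant switch is a plain argparse boolean: it defaults to False when absent, and passing it asks both trainers to drop warmup and hold the learning rate constant on every cycle after the first. A minimal standalone sketch of the flag's semantics (the parser here is a hypothetical reproduction, not the repo's full argument list):

import argparse

# Hypothetical minimal parser reproducing only the new flag.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--cycle_constant",
    action="store_true",   # False unless the flag is passed
    help="Use constant LR on cycles > 1."
)

print(parser.parse_args([]).cycle_constant)                    # False
print(parser.parse_args(["--cycle_constant"]).cycle_constant)  # True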
diff --git a/train_lora.py b/train_lora.py
index 5c78664..4d4c16a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -95,6 +95,11 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
+        "--cycle_constant",
+        action="store_true",
+        help="Use constant LR on cycles > 1."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -910,7 +915,6 @@ def main():
 
     create_lr_scheduler = partial(
         get_scheduler,
-        args.lr_scheduler,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
         annealing_func=args.lr_annealing_func,
@@ -918,7 +922,6 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        warmup_epochs=args.lr_warmup_epochs,
         mid_point=args.lr_mid_point,
     )
 
@@ -971,6 +974,10 @@ def main():
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
+        if args.cycle_constant and training_iter == 1:
+            args.lr_scheduler = "constant"
+            args.lr_warmup_epochs = 0
+
         params_to_optimize = []
 
         if len(args.placeholder_tokens) != 0:
@@ -1005,10 +1012,12 @@ def main():
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
+            args.lr_scheduler,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             optimizer=lora_optimizer,
             num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
             train_epochs=num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
         )
 
         lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
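
The two relocations in this file go together: functools.partial freezes argument values when the partial is built, so leaving args.lr_scheduler and warmup_epochs inside partial(...) would lock in the values from before the cycle loop mutates them. Moving both to the per-cycle create_lr_scheduler(...) call lets each cycle pick up the current values. A minimal sketch of that early-vs-late binding difference (the factory and names here are illustrative, not the repo's API):

from functools import partial

def get_scheduler(name, warmup_epochs=0):
    # Stand-in for the real scheduler factory.
    return (name, warmup_epochs)

cfg = {"name": "one_cycle", "warmup": 10}

frozen = partial(get_scheduler, cfg["name"], warmup_epochs=cfg["warmup"])
cfg["name"], cfg["warmup"] = "constant", 0   # mutation, as in the cycle loop

print(frozen())                                                 # ('one_cycle', 10): stale, bound at creation
print(get_scheduler(cfg["name"], warmup_epochs=cfg["warmup"]))  # ('constant', 0): bound at call time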
diff --git a/train_ti.py b/train_ti.py
index 45e730a..c452269 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -79,6 +79,11 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
+        "--cycle_constant",
+        action="store_true",
+        help="Use constant LR on cycles > 1."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -926,6 +931,10 @@ def main():
         print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")
 
+        if args.cycle_constant and training_iter == 1:
+            args.lr_scheduler = "constant"
+            args.lr_warmup_epochs = 0
+
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_embedding.parameters(),
             lr=learning_rate,
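
Since training_iter is zero-based, the training_iter == 1 check fires at the start of the second cycle, and because args is mutated in place the constant, warmup-free schedule then sticks for every later cycle too. A standalone sketch of the resulting per-cycle settings (the scheduler name and cycle count are illustrative):

from types import SimpleNamespace

# Hypothetical stand-in for the parsed arguments.
args = SimpleNamespace(cycle_constant=True, lr_scheduler="one_cycle", lr_warmup_epochs=10)

for training_iter in range(3):
    if args.cycle_constant and training_iter == 1:
        args.lr_scheduler = "constant"   # in-place change persists for later cycles
        args.lr_warmup_epochs = 0
    print(f"cycle {training_iter + 1}: {args.lr_scheduler}, warmup={args.lr_warmup_epochs}")

# cycle 1: one_cycle, warmup=10
# cycle 2: constant, warmup=0
# cycle 3: constant, warmup=0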