summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
index 45e730a..c452269 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -79,6 +79,11 @@ def parse_args():
79 help="Learning rate decay per cycle." 79 help="Learning rate decay per cycle."
80 ) 80 )
81 parser.add_argument( 81 parser.add_argument(
82 "--cycle_constant",
83 action="store_true",
84 help="Use constant LR on cycles > 1."
85 )
86 parser.add_argument(
82 "--placeholder_tokens", 87 "--placeholder_tokens",
83 type=str, 88 type=str,
84 nargs='*', 89 nargs='*',
@@ -926,6 +931,10 @@ def main():
926 print(f"------------ TI cycle {training_iter + 1} ------------") 931 print(f"------------ TI cycle {training_iter + 1} ------------")
927 print("") 932 print("")
928 933
934 if args.cycle_constant and training_iter == 1:
935 args.lr_scheduler = "constant"
936 args.lr_warmup_epochs = 0
937
929 optimizer = create_optimizer( 938 optimizer = create_optimizer(
930 text_encoder.text_model.embeddings.token_embedding.parameters(), 939 text_encoder.text_model.embeddings.token_embedding.parameters(),
931 lr=learning_rate, 940 lr=learning_rate,