diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 9 |
1 file changed, 9 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index 45e730a..c452269 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -79,6 +79,11 @@ def parse_args(): | |||
79 | help="Learning rate decay per cycle." | 79 | help="Learning rate decay per cycle." |
80 | ) | 80 | ) |
81 | parser.add_argument( | 81 | parser.add_argument( |
82 | "--cycle_constant", | ||
83 | action="store_true", | ||
84 | help="Use constant LR on cycles > 1." | ||
85 | ) | ||
86 | parser.add_argument( | ||
82 | "--placeholder_tokens", | 87 | "--placeholder_tokens", |
83 | type=str, | 88 | type=str, |
84 | nargs='*', | 89 | nargs='*', |
@@ -926,6 +931,10 @@ def main(): | |||
926 | print(f"------------ TI cycle {training_iter + 1} ------------") | 931 | print(f"------------ TI cycle {training_iter + 1} ------------") |
927 | print("") | 932 | print("") |
928 | 933 | ||
934 | if args.cycle_constant and training_iter == 1: | ||
935 | args.lr_scheduler = "constant" | ||
936 | args.lr_warmup_epochs = 0 | ||
937 | |||
929 | optimizer = create_optimizer( | 938 | optimizer = create_optimizer( |
930 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 939 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
931 | lr=learning_rate, | 940 | lr=learning_rate, |