diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 9 |
1 file changed, 9 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
index 45e730a..c452269 100644
--- a/train_ti.py
+++ b/train_ti.py
| @@ -79,6 +79,11 @@ def parse_args(): | |||
| 79 | help="Learning rate decay per cycle." | 79 | help="Learning rate decay per cycle." |
| 80 | ) | 80 | ) |
| 81 | parser.add_argument( | 81 | parser.add_argument( |
| 82 | "--cycle_constant", | ||
| 83 | action="store_true", | ||
| 84 | help="Use constant LR on cycles > 1." | ||
| 85 | ) | ||
| 86 | parser.add_argument( | ||
| 82 | "--placeholder_tokens", | 87 | "--placeholder_tokens", |
| 83 | type=str, | 88 | type=str, |
| 84 | nargs='*', | 89 | nargs='*', |
| @@ -926,6 +931,10 @@ def main(): | |||
| 926 | print(f"------------ TI cycle {training_iter + 1} ------------") | 931 | print(f"------------ TI cycle {training_iter + 1} ------------") |
| 927 | print("") | 932 | print("") |
| 928 | 933 | ||
| 934 | if args.cycle_constant and training_iter == 1: | ||
| 935 | args.lr_scheduler = "constant" | ||
| 936 | args.lr_warmup_epochs = 0 | ||
| 937 | |||
| 929 | optimizer = create_optimizer( | 938 | optimizer = create_optimizer( |
| 930 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 939 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 931 | lr=learning_rate, | 940 | lr=learning_rate, |
