author     Volpeon <git@volpeon.ink>    2023-04-16 18:20:38 +0200
committer  Volpeon <git@volpeon.ink>    2023-04-16 18:20:38 +0200
commit     bd951892e300a0e21cb0e10fe261cb647ca160cd (patch)
tree       40f73452f6886be687cb6552588b114dc034fb00
parent     Fix (diff)
Added option to use constant LR on cycles > 1
-rw-r--r--  train_lora.py | 13
-rw-r--r--  train_ti.py   |  9
2 files changed, 20 insertions, 2 deletions
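In effect, the new `--cycle_constant` flag leaves the first training cycle on the configured scheduler and, from the second cycle onward, forces a constant learning rate with no warmup. Below is a minimal standalone sketch of that behavior; only the flag definition and the `training_iter == 1` override are taken from this patch, while the `--lr_scheduler` / `--lr_warmup_epochs` defaults are placeholders, not the scripts' real defaults.

```python
import argparse

# Toy reproduction of the relevant parse_args() pieces. The defaults for
# --lr_scheduler and --lr_warmup_epochs are illustrative placeholders.
parser = argparse.ArgumentParser()
parser.add_argument("--lr_scheduler", type=str, default="one_cycle")
parser.add_argument("--lr_warmup_epochs", type=int, default=10)
parser.add_argument(
    "--cycle_constant",
    action="store_true",
    help="Use constant LR on cycles > 1."
)
args = parser.parse_args(["--cycle_constant"])

for training_iter in range(3):
    # As in the patch: the override fires once, at the start of the second
    # cycle, and because args is mutated in place it sticks for every
    # cycle after that.
    if args.cycle_constant and training_iter == 1:
        args.lr_scheduler = "constant"
        args.lr_warmup_epochs = 0
    print(f"cycle {training_iter + 1}: "
          f"scheduler={args.lr_scheduler}, warmup_epochs={args.lr_warmup_epochs}")
```

Running this prints the configured scheduler with warmup for cycle 1, then `constant` with `warmup_epochs=0` for cycles 2 and 3.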
diff --git a/train_lora.py b/train_lora.py
index 5c78664..4d4c16a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -95,6 +95,11 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
+        "--cycle_constant",
+        action="store_true",
+        help="Use constant LR on cycles > 1."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -910,7 +915,6 @@ def main():
 
     create_lr_scheduler = partial(
         get_scheduler,
-        args.lr_scheduler,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
         annealing_func=args.lr_annealing_func,
@@ -918,7 +922,6 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        warmup_epochs=args.lr_warmup_epochs,
         mid_point=args.lr_mid_point,
     )
 
@@ -971,6 +974,10 @@ def main():
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
+        if args.cycle_constant and training_iter == 1:
+            args.lr_scheduler = "constant"
+            args.lr_warmup_epochs = 0
+
         params_to_optimize = []
 
         if len(args.placeholder_tokens) != 0:
@@ -1005,10 +1012,12 @@ def main():
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
+            args.lr_scheduler,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             optimizer=lora_optimizer,
             num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
             train_epochs=num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
         )
 
         lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
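Moving `args.lr_scheduler` and `warmup_epochs=args.lr_warmup_epochs` out of the `functools.partial` above and into the per-cycle `create_lr_scheduler(...)` call is what lets the `--cycle_constant` override actually take effect: arguments bound in the partial are evaluated once, before the cycle loop, so later mutations of `args` would never reach the scheduler. A small self-contained sketch of that early- vs. late-binding difference follows; the `get_scheduler` stub and the `Args` class are stand-ins, not the project's real implementations.

```python
from functools import partial


def get_scheduler(name, warmup_epochs=0):
    # Stub standing in for the real scheduler factory: it only reports
    # which values it was actually called with.
    return f"{name} (warmup_epochs={warmup_epochs})"


class Args:
    lr_scheduler = "one_cycle"
    lr_warmup_epochs = 10


args = Args()

# Early binding: the values are captured when the partial is built,
# i.e. before the cycle loop runs.
frozen = partial(get_scheduler, args.lr_scheduler,
                 warmup_epochs=args.lr_warmup_epochs)

# The per-cycle override from the patch happens afterwards.
args.lr_scheduler = "constant"
args.lr_warmup_epochs = 0

print(frozen())                          # one_cycle (warmup_epochs=10) -- override ignored
print(get_scheduler(args.lr_scheduler,   # constant (warmup_epochs=0)  -- override applied
                    warmup_epochs=args.lr_warmup_epochs))
```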
diff --git a/train_ti.py b/train_ti.py
index 45e730a..c452269 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -79,6 +79,11 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
+        "--cycle_constant",
+        action="store_true",
+        help="Use constant LR on cycles > 1."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -926,6 +931,10 @@ def main():
         print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")
 
+        if args.cycle_constant and training_iter == 1:
+            args.lr_scheduler = "constant"
+            args.lr_warmup_epochs = 0
+
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_embedding.parameters(),
             lr=learning_rate,
