diff options
author | Volpeon <git@volpeon.ink> | 2023-01-04 12:18:07 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-04 12:18:07 +0100 |
commit | 01fee7d37a116265edb0f16e0b2f75d2116eb9f6 (patch) | |
tree | 6389f385191247fb3639900da0d29a3064259cb7 /train_ti.py | |
parent | Better eval generator (diff) | |
download | textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.tar.gz textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.tar.bz2 textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.zip |
Various updates
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 8 |
1 file changed, 8 insertions(+), 0 deletions(-)
diff --git a/train_ti.py b/train_ti.py index 6f116c3..1b60f64 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -260,6 +260,12 @@ def parse_args(): | |||
260 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 260 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' |
261 | ) | 261 | ) |
262 | parser.add_argument( | 262 | parser.add_argument( |
263 | "--lr_min_lr", | ||
264 | type=float, | ||
265 | default=None, | ||
266 | help="Minimum learning rate in the lr scheduler." | ||
267 | ) | ||
268 | parser.add_argument( | ||
263 | "--use_8bit_adam", | 269 | "--use_8bit_adam", |
264 | action="store_true", | 270 | action="store_true", |
265 | help="Whether or not to use 8-bit Adam from bitsandbytes." | 271 | help="Whether or not to use 8-bit Adam from bitsandbytes." |
@@ -744,6 +750,7 @@ def main(): | |||
744 | if args.find_lr: | 750 | if args.find_lr: |
745 | lr_scheduler = None | 751 | lr_scheduler = None |
746 | elif args.lr_scheduler == "one_cycle": | 752 | elif args.lr_scheduler == "one_cycle": |
753 | lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate | ||
747 | lr_scheduler = get_one_cycle_schedule( | 754 | lr_scheduler = get_one_cycle_schedule( |
748 | optimizer=optimizer, | 755 | optimizer=optimizer, |
749 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 756 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
@@ -751,6 +758,7 @@ def main(): | |||
751 | annealing=args.lr_annealing_func, | 758 | annealing=args.lr_annealing_func, |
752 | warmup_exp=args.lr_warmup_exp, | 759 | warmup_exp=args.lr_warmup_exp, |
753 | annealing_exp=args.lr_annealing_exp, | 760 | annealing_exp=args.lr_annealing_exp, |
761 | min_lr=lr_min_lr, | ||
754 | ) | 762 | ) |
755 | elif args.lr_scheduler == "cosine_with_restarts": | 763 | elif args.lr_scheduler == "cosine_with_restarts": |
756 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 764 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |