| author | Volpeon <git@volpeon.ink> | 2023-01-04 12:18:07 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-04 12:18:07 +0100 |
| commit | 01fee7d37a116265edb0f16e0b2f75d2116eb9f6 (patch) | |
| tree | 6389f385191247fb3639900da0d29a3064259cb7 /train_ti.py | |
| parent | Better eval generator (diff) | |
Various updates
Diffstat (limited to 'train_ti.py')

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_ti.py | 8 |

1 file changed, 8 insertions(+), 0 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index 6f116c3..1b60f64 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -260,6 +260,12 @@ def parse_args():
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
+        "--lr_min_lr",
+        type=float,
+        default=None,
+        help="Minimum learning rate in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -744,6 +750,7 @@ def main():
     if args.find_lr:
         lr_scheduler = None
     elif args.lr_scheduler == "one_cycle":
+        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -751,6 +758,7 @@ def main():
             annealing=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
+            min_lr=lr_min_lr,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
```
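The notable detail in this change is how `--lr_min_lr` is interpreted: the flag is given as an absolute learning rate on the command line, converted to a fraction of `--learning_rate` (`args.lr_min_lr / args.learning_rate`), and falls back to a fraction of 0.04 when unset; that fraction is then passed to the repository's `get_one_cycle_schedule` as `min_lr`. Below is a minimal sketch of how such a fractional floor could work in a one-cycle-style schedule built on `torch.optim.lr_scheduler.LambdaLR`. The helper name `one_cycle_with_floor`, the linear-warmup/cosine-annealing shape, and the concrete values are assumptions for illustration, not the actual implementation in this repository.

```python
# Sketch only: NOT the repository's get_one_cycle_schedule, just an illustration
# of a one-cycle-style schedule whose floor is a fraction of the base learning rate.
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def one_cycle_with_floor(optimizer, num_training_steps, warmup_steps, min_lr=0.04):
    """Linear warmup to the base LR, then cosine annealing down to min_lr * base_lr."""

    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup: factor goes from 0 to 1.
            return step / max(1, warmup_steps)
        # Cosine annealing: factor goes from 1 down to min_lr.
        progress = (step - warmup_steps) / max(1, num_training_steps - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
        return min_lr + (1.0 - min_lr) * cosine

    return LambdaLR(optimizer, lr_lambda)


# Usage mirroring the diff: the (hypothetical) CLI value for --lr_min_lr is absolute
# and is divided by --learning_rate before being handed to the scheduler.
learning_rate = 1e-3
lr_min_lr_arg = 4e-5  # hypothetical --lr_min_lr value; 4e-5 / 1e-3 = 0.04

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=learning_rate)

min_lr_fraction = 0.04 if lr_min_lr_arg is None else lr_min_lr_arg / learning_rate
scheduler = one_cycle_with_floor(
    optimizer, num_training_steps=1000, warmup_steps=100, min_lr=min_lr_fraction
)
```

Because `LambdaLR` multiplies the optimizer's base learning rate by the returned factor, expressing the floor as a fraction keeps it proportional to whatever `--learning_rate` is set to, which matches how the diff normalizes `args.lr_min_lr` before passing it on.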
