diff options
author | Volpeon <git@volpeon.ink> | 2023-01-04 12:18:07 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-04 12:18:07 +0100 |
commit | 01fee7d37a116265edb0f16e0b2f75d2116eb9f6 (patch) | |
tree | 6389f385191247fb3639900da0d29a3064259cb7 /train_dreambooth.py | |
parent | Better eval generator (diff) | |
download | textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.tar.gz textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.tar.bz2 textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.zip |
Various updates
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 5e6e35d..2e0696b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -269,6 +269,12 @@ def parse_args(): | |||
269 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 269 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' |
270 | ) | 270 | ) |
271 | parser.add_argument( | 271 | parser.add_argument( |
272 | "--lr_min_lr", | ||
273 | type=float, | ||
274 | default=None, | ||
275 | help="Minimum learning rate in the lr scheduler." | ||
276 | ) | ||
277 | parser.add_argument( | ||
272 | "--use_ema", | 278 | "--use_ema", |
273 | action="store_true", | 279 | action="store_true", |
274 | default=True, | 280 | default=True, |
@@ -799,6 +805,7 @@ def main(): | |||
799 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps | 805 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps |
800 | 806 | ||
801 | if args.lr_scheduler == "one_cycle": | 807 | if args.lr_scheduler == "one_cycle": |
808 | lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate | ||
802 | lr_scheduler = get_one_cycle_schedule( | 809 | lr_scheduler = get_one_cycle_schedule( |
803 | optimizer=optimizer, | 810 | optimizer=optimizer, |
804 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 811 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
@@ -806,6 +813,7 @@ def main(): | |||
806 | annealing=args.lr_annealing_func, | 813 | annealing=args.lr_annealing_func, |
807 | warmup_exp=args.lr_warmup_exp, | 814 | warmup_exp=args.lr_warmup_exp, |
808 | annealing_exp=args.lr_annealing_exp, | 815 | annealing_exp=args.lr_annealing_exp, |
816 | min_lr=lr_min_lr, | ||
809 | ) | 817 | ) |
810 | elif args.lr_scheduler == "cosine_with_restarts": | 818 | elif args.lr_scheduler == "cosine_with_restarts": |
811 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 819 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |