path: root/train_dreambooth.py
author    Volpeon <git@volpeon.ink>  2023-01-04 12:18:07 +0100
committer Volpeon <git@volpeon.ink>  2023-01-04 12:18:07 +0100
commit    01fee7d37a116265edb0f16e0b2f75d2116eb9f6 (patch)
tree      6389f385191247fb3639900da0d29a3064259cb7  /train_dreambooth.py
parent    Better eval generator (diff)
Various updates
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  8
1 file changed, 8 insertions(+), 0 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5e6e35d..2e0696b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -269,6 +269,12 @@ def parse_args():
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
+        "--lr_min_lr",
+        type=float,
+        default=None,
+        help="Minimum learning rate in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -799,6 +805,7 @@ def main():
     warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
 
     if args.lr_scheduler == "one_cycle":
+        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -806,6 +813,7 @@ def main():
             annealing=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
+            min_lr=lr_min_lr,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
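
Note on the change: --lr_min_lr is given in absolute learning-rate units and is normalized to a
fraction of the base LR before being handed to the scheduler, falling back to 0.04 when unset.
Below is a minimal sketch of how a one-cycle schedule could consume such a fractional min_lr as
its annealing floor. The function name, the linear warmup, and the cosine annealing shape here
are assumptions for illustration only; the repo's actual get_one_cycle_schedule also takes
annealing, warmup_exp, and annealing_exp parameters that this sketch does not model.

    # Sketch only: a LambdaLR-based one-cycle schedule whose annealing
    # bottoms out at min_lr, expressed as a fraction of the base LR
    # (matching the normalization done in the diff above).
    import math

    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import LambdaLR


    def one_cycle_schedule_sketch(
        optimizer: Optimizer,
        num_training_steps: int,
        num_warmup_steps: int = 0,
        min_lr: float = 0.04,  # fraction of the base LR, not an absolute value
    ) -> LambdaLR:
        def lr_lambda(step: int) -> float:
            if step < num_warmup_steps:
                # Linear warmup from 0 up to the base LR.
                return step / max(1, num_warmup_steps)
            # Cosine annealing from the base LR down to min_lr * base LR.
            progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
            return min_lr + (1.0 - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

        return LambdaLR(optimizer, lr_lambda)

With the normalization in the diff, running with e.g. --learning_rate 1e-5 --lr_min_lr 4e-7
gives lr_min_lr = 4e-7 / 1e-5 = 0.04, i.e. the LR anneals down to 4% of its peak; omitting
--lr_min_lr lands on the same 0.04 default.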