summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-04 12:18:07 +0100
committerVolpeon <git@volpeon.ink>2023-01-04 12:18:07 +0100
commit01fee7d37a116265edb0f16e0b2f75d2116eb9f6 (patch)
tree6389f385191247fb3639900da0d29a3064259cb7 /train_ti.py
parentBetter eval generator (diff)
downloadtextual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.tar.gz
textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.tar.bz2
textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.zip
Various updates
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
index 6f116c3..1b60f64 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -260,6 +260,12 @@ def parse_args():
260 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' 260 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
261 ) 261 )
262 parser.add_argument( 262 parser.add_argument(
263 "--lr_min_lr",
264 type=float,
265 default=None,
266 help="Minimum learning rate in the lr scheduler."
267 )
268 parser.add_argument(
263 "--use_8bit_adam", 269 "--use_8bit_adam",
264 action="store_true", 270 action="store_true",
265 help="Whether or not to use 8-bit Adam from bitsandbytes." 271 help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -744,6 +750,7 @@ def main():
744 if args.find_lr: 750 if args.find_lr:
745 lr_scheduler = None 751 lr_scheduler = None
746 elif args.lr_scheduler == "one_cycle": 752 elif args.lr_scheduler == "one_cycle":
753 lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
747 lr_scheduler = get_one_cycle_schedule( 754 lr_scheduler = get_one_cycle_schedule(
748 optimizer=optimizer, 755 optimizer=optimizer,
749 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 756 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -751,6 +758,7 @@ def main():
751 annealing=args.lr_annealing_func, 758 annealing=args.lr_annealing_func,
752 warmup_exp=args.lr_warmup_exp, 759 warmup_exp=args.lr_warmup_exp,
753 annealing_exp=args.lr_annealing_exp, 760 annealing_exp=args.lr_annealing_exp,
761 min_lr=lr_min_lr,
754 ) 762 )
755 elif args.lr_scheduler == "cosine_with_restarts": 763 elif args.lr_scheduler == "cosine_with_restarts":
756 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 764 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(