summaryrefslogtreecommitdiffstats
path: root/dreambooth_plus.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-18 18:08:32 +0200
committerVolpeon <git@volpeon.ink>2022-10-18 18:08:32 +0200
commit2ddfbd65e482fa2361e8ba41b657656f825c9143 (patch)
tree41cc82e23d82dd620c81f2715a50969b832e9bda /dreambooth_plus.py
parentImproved prompt handling (diff)
downloadtextual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.tar.gz
textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.tar.bz2
textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.zip
Adapted other scripts for new prompt processing
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--dreambooth_plus.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index fa3a22b..06ff45b 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -125,7 +125,7 @@ def parse_args():
125 parser.add_argument( 125 parser.add_argument(
126 "--max_train_steps", 126 "--max_train_steps",
127 type=int, 127 type=int,
128 default=1400, 128 default=2400,
129 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 129 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
130 ) 130 )
131 parser.add_argument( 131 parser.add_argument(
@@ -752,8 +752,8 @@ def main():
752 optimizer=optimizer, 752 optimizer=optimizer,
753 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 753 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
754 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 754 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
755 num_cycles=args.lr_cycles or math.ceil( 755 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
756 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), 756 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
757 ) 757 )
758 else: 758 else:
759 lr_scheduler = get_scheduler( 759 lr_scheduler = get_scheduler(