summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-10 09:22:02 +0100
committerVolpeon <git@volpeon.ink>2023-01-10 09:22:02 +0100
commit33e7d2ed37e32657ca94d92815043026c4cea7c0 (patch)
tree0af4d6ad0ba92a168e3ec17675147c76afe1baf0 /train_dreambooth.py
parentEnable buckets for validation, fixed vaildation repeat arg (diff)
downloadtextual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.tar.gz
textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.tar.bz2
textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.zip
Added arg to disable tag shuffling
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1a1f516..48a513c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -133,6 +133,12 @@ def parse_args():
133 help="Tag dropout probability.", 133 help="Tag dropout probability.",
134 ) 134 )
135 parser.add_argument( 135 parser.add_argument(
136 "--tag_shuffle",
137 type="store_true",
138 default=True,
139 help="Shuffle tags.",
140 )
141 parser.add_argument(
136 "--vector_dropout", 142 "--vector_dropout",
137 type=int, 143 type=int,
138 default=0, 144 default=0,
@@ -398,7 +404,7 @@ def parse_args():
398 parser.add_argument( 404 parser.add_argument(
399 "--sample_steps", 405 "--sample_steps",
400 type=int, 406 type=int,
401 default=15, 407 default=20,
402 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 408 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
403 ) 409 )
404 parser.add_argument( 410 parser.add_argument(
@@ -768,6 +774,7 @@ def main():
768 bucket_step_size=args.bucket_step_size, 774 bucket_step_size=args.bucket_step_size,
769 bucket_max_pixels=args.bucket_max_pixels, 775 bucket_max_pixels=args.bucket_max_pixels,
770 dropout=args.tag_dropout, 776 dropout=args.tag_dropout,
777 shuffle=args.tag_shuffle,
771 template_key=args.train_data_template, 778 template_key=args.train_data_template,
772 valid_set_size=args.valid_set_size, 779 valid_set_size=args.valid_set_size,
773 valid_set_repeat=args.valid_set_repeat, 780 valid_set_repeat=args.valid_set_repeat,