summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py5
-rw-r--r--train_ti.py8
2 files changed, 6 insertions, 7 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 21fe2fb..0182693 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -133,9 +133,8 @@ def parse_args():
133 help="Tag dropout probability.", 133 help="Tag dropout probability.",
134 ) 134 )
135 parser.add_argument( 135 parser.add_argument(
136 "--tag_shuffle", 136 "--no_tag_shuffle",
137 action="store_true", 137 action="store_true",
138 default=True,
139 help="Shuffle tags.", 138 help="Shuffle tags.",
140 ) 139 )
141 parser.add_argument( 140 parser.add_argument(
@@ -774,7 +773,7 @@ def main():
774 bucket_step_size=args.bucket_step_size, 773 bucket_step_size=args.bucket_step_size,
775 bucket_max_pixels=args.bucket_max_pixels, 774 bucket_max_pixels=args.bucket_max_pixels,
776 dropout=args.tag_dropout, 775 dropout=args.tag_dropout,
777 shuffle=args.tag_shuffle, 776 shuffle=not args.no_tag_shuffle,
778 template_key=args.train_data_template, 777 template_key=args.train_data_template,
779 valid_set_size=args.valid_set_size, 778 valid_set_size=args.valid_set_size,
780 valid_set_repeat=args.valid_set_repeat, 779 valid_set_repeat=args.valid_set_repeat,
diff --git a/train_ti.py b/train_ti.py
index b88ccc3..4e2c3c5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -165,18 +165,18 @@ def parse_args():
165 parser.add_argument( 165 parser.add_argument(
166 "--tag_dropout", 166 "--tag_dropout",
167 type=float, 167 type=float,
168 default=0.1, 168 default=0,
169 help="Tag dropout probability.", 169 help="Tag dropout probability.",
170 ) 170 )
171 parser.add_argument( 171 parser.add_argument(
172 "--tag_shuffle", 172 "--no_tag_shuffle",
173 action="store_true", 173 action="store_true",
174 help="Shuffle tags.", 174 help="Shuffle tags.",
175 ) 175 )
176 parser.add_argument( 176 parser.add_argument(
177 "--vector_dropout", 177 "--vector_dropout",
178 type=int, 178 type=int,
179 default=0, 179 default=0.1,
180 help="Vector dropout probability.", 180 help="Vector dropout probability.",
181 ) 181 )
182 parser.add_argument( 182 parser.add_argument(
@@ -750,7 +750,7 @@ def main():
750 bucket_step_size=args.bucket_step_size, 750 bucket_step_size=args.bucket_step_size,
751 bucket_max_pixels=args.bucket_max_pixels, 751 bucket_max_pixels=args.bucket_max_pixels,
752 dropout=args.tag_dropout, 752 dropout=args.tag_dropout,
753 shuffle=args.tag_shuffle, 753 shuffle=not args.no_tag_shuffle,
754 template_key=args.train_data_template, 754 template_key=args.train_data_template,
755 valid_set_size=args.valid_set_size, 755 valid_set_size=args.valid_set_size,
756 valid_set_repeat=args.valid_set_repeat, 756 valid_set_repeat=args.valid_set_repeat,