diff options
-rw-r--r-- | train_dreambooth.py | 5 | ||||
-rw-r--r-- | train_ti.py | 8 |
2 files changed, 6 insertions, 7 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 21fe2fb..0182693 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -133,9 +133,8 @@ def parse_args(): | |||
133 | help="Tag dropout probability.", | 133 | help="Tag dropout probability.", |
134 | ) | 134 | ) |
135 | parser.add_argument( | 135 | parser.add_argument( |
136 | "--tag_shuffle", | 136 | "--no_tag_shuffle", |
137 | action="store_true", | 137 | action="store_true", |
138 | default=True, | ||
139 | help="Shuffle tags.", | 138 | help="Shuffle tags.", |
140 | ) | 139 | ) |
141 | parser.add_argument( | 140 | parser.add_argument( |
@@ -774,7 +773,7 @@ def main(): | |||
774 | bucket_step_size=args.bucket_step_size, | 773 | bucket_step_size=args.bucket_step_size, |
775 | bucket_max_pixels=args.bucket_max_pixels, | 774 | bucket_max_pixels=args.bucket_max_pixels, |
776 | dropout=args.tag_dropout, | 775 | dropout=args.tag_dropout, |
777 | shuffle=args.tag_shuffle, | 776 | shuffle=not args.no_tag_shuffle, |
778 | template_key=args.train_data_template, | 777 | template_key=args.train_data_template, |
779 | valid_set_size=args.valid_set_size, | 778 | valid_set_size=args.valid_set_size, |
780 | valid_set_repeat=args.valid_set_repeat, | 779 | valid_set_repeat=args.valid_set_repeat, |
diff --git a/train_ti.py b/train_ti.py index b88ccc3..4e2c3c5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -165,18 +165,18 @@ def parse_args(): | |||
165 | parser.add_argument( | 165 | parser.add_argument( |
166 | "--tag_dropout", | 166 | "--tag_dropout", |
167 | type=float, | 167 | type=float, |
168 | default=0.1, | 168 | default=0, |
169 | help="Tag dropout probability.", | 169 | help="Tag dropout probability.", |
170 | ) | 170 | ) |
171 | parser.add_argument( | 171 | parser.add_argument( |
172 | "--tag_shuffle", | 172 | "--no_tag_shuffle", |
173 | action="store_true", | 173 | action="store_true", |
174 | help="Shuffle tags.", | 174 | help="Shuffle tags.", |
175 | ) | 175 | ) |
176 | parser.add_argument( | 176 | parser.add_argument( |
177 | "--vector_dropout", | 177 | "--vector_dropout", |
178 | type=int, | 178 | type=int, |
179 | default=0, | 179 | default=0.1, |
180 | help="Vector dropout probability.", | 180 | help="Vector dropout probability.", |
181 | ) | 181 | ) |
182 | parser.add_argument( | 182 | parser.add_argument( |
@@ -750,7 +750,7 @@ def main(): | |||
750 | bucket_step_size=args.bucket_step_size, | 750 | bucket_step_size=args.bucket_step_size, |
751 | bucket_max_pixels=args.bucket_max_pixels, | 751 | bucket_max_pixels=args.bucket_max_pixels, |
752 | dropout=args.tag_dropout, | 752 | dropout=args.tag_dropout, |
753 | shuffle=args.tag_shuffle, | 753 | shuffle=not args.no_tag_shuffle, |
754 | template_key=args.train_data_template, | 754 | template_key=args.train_data_template, |
755 | valid_set_size=args.valid_set_size, | 755 | valid_set_size=args.valid_set_size, |
756 | valid_set_repeat=args.valid_set_repeat, | 756 | valid_set_repeat=args.valid_set_repeat, |