diff options
| -rw-r--r-- | train_dreambooth.py | 5 | ||||
| -rw-r--r-- | train_ti.py | 8 |
2 files changed, 6 insertions, 7 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 21fe2fb..0182693 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -133,9 +133,8 @@ def parse_args(): | |||
| 133 | help="Tag dropout probability.", | 133 | help="Tag dropout probability.", |
| 134 | ) | 134 | ) |
| 135 | parser.add_argument( | 135 | parser.add_argument( |
| 136 | "--tag_shuffle", | 136 | "--no_tag_shuffle", |
| 137 | action="store_true", | 137 | action="store_true", |
| 138 | default=True, | ||
| 139 | help="Shuffle tags.", | 138 | help="Shuffle tags.", |
| 140 | ) | 139 | ) |
| 141 | parser.add_argument( | 140 | parser.add_argument( |
| @@ -774,7 +773,7 @@ def main(): | |||
| 774 | bucket_step_size=args.bucket_step_size, | 773 | bucket_step_size=args.bucket_step_size, |
| 775 | bucket_max_pixels=args.bucket_max_pixels, | 774 | bucket_max_pixels=args.bucket_max_pixels, |
| 776 | dropout=args.tag_dropout, | 775 | dropout=args.tag_dropout, |
| 777 | shuffle=args.tag_shuffle, | 776 | shuffle=not args.no_tag_shuffle, |
| 778 | template_key=args.train_data_template, | 777 | template_key=args.train_data_template, |
| 779 | valid_set_size=args.valid_set_size, | 778 | valid_set_size=args.valid_set_size, |
| 780 | valid_set_repeat=args.valid_set_repeat, | 779 | valid_set_repeat=args.valid_set_repeat, |
diff --git a/train_ti.py b/train_ti.py index b88ccc3..4e2c3c5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -165,18 +165,18 @@ def parse_args(): | |||
| 165 | parser.add_argument( | 165 | parser.add_argument( |
| 166 | "--tag_dropout", | 166 | "--tag_dropout", |
| 167 | type=float, | 167 | type=float, |
| 168 | default=0.1, | 168 | default=0, |
| 169 | help="Tag dropout probability.", | 169 | help="Tag dropout probability.", |
| 170 | ) | 170 | ) |
| 171 | parser.add_argument( | 171 | parser.add_argument( |
| 172 | "--tag_shuffle", | 172 | "--no_tag_shuffle", |
| 173 | action="store_true", | 173 | action="store_true", |
| 174 | help="Shuffle tags.", | 174 | help="Shuffle tags.", |
| 175 | ) | 175 | ) |
| 176 | parser.add_argument( | 176 | parser.add_argument( |
| 177 | "--vector_dropout", | 177 | "--vector_dropout", |
| 178 | type=int, | 178 | type=int, |
| 179 | default=0, | 179 | default=0.1, |
| 180 | help="Vector dropout probability.", | 180 | help="Vector dropout probability.", |
| 181 | ) | 181 | ) |
| 182 | parser.add_argument( | 182 | parser.add_argument( |
| @@ -750,7 +750,7 @@ def main(): | |||
| 750 | bucket_step_size=args.bucket_step_size, | 750 | bucket_step_size=args.bucket_step_size, |
| 751 | bucket_max_pixels=args.bucket_max_pixels, | 751 | bucket_max_pixels=args.bucket_max_pixels, |
| 752 | dropout=args.tag_dropout, | 752 | dropout=args.tag_dropout, |
| 753 | shuffle=args.tag_shuffle, | 753 | shuffle=not args.no_tag_shuffle, |
| 754 | template_key=args.train_data_template, | 754 | template_key=args.train_data_template, |
| 755 | valid_set_size=args.valid_set_size, | 755 | valid_set_size=args.valid_set_size, |
| 756 | valid_set_repeat=args.valid_set_repeat, | 756 | valid_set_repeat=args.valid_set_repeat, |
