diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1a1f516..48a513c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -133,6 +133,12 @@ def parse_args(): | |||
133 | help="Tag dropout probability.", | 133 | help="Tag dropout probability.", |
134 | ) | 134 | ) |
135 | parser.add_argument( | 135 | parser.add_argument( |
136 | "--tag_shuffle", | ||
137 | type="store_true", | ||
138 | default=True, | ||
139 | help="Shuffle tags.", | ||
140 | ) | ||
141 | parser.add_argument( | ||
136 | "--vector_dropout", | 142 | "--vector_dropout", |
137 | type=int, | 143 | type=int, |
138 | default=0, | 144 | default=0, |
@@ -398,7 +404,7 @@ def parse_args(): | |||
398 | parser.add_argument( | 404 | parser.add_argument( |
399 | "--sample_steps", | 405 | "--sample_steps", |
400 | type=int, | 406 | type=int, |
401 | default=15, | 407 | default=20, |
402 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 408 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
403 | ) | 409 | ) |
404 | parser.add_argument( | 410 | parser.add_argument( |
@@ -768,6 +774,7 @@ def main(): | |||
768 | bucket_step_size=args.bucket_step_size, | 774 | bucket_step_size=args.bucket_step_size, |
769 | bucket_max_pixels=args.bucket_max_pixels, | 775 | bucket_max_pixels=args.bucket_max_pixels, |
770 | dropout=args.tag_dropout, | 776 | dropout=args.tag_dropout, |
777 | shuffle=args.tag_shuffle, | ||
771 | template_key=args.train_data_template, | 778 | template_key=args.train_data_template, |
772 | valid_set_size=args.valid_set_size, | 779 | valid_set_size=args.valid_set_size, |
773 | valid_set_repeat=args.valid_set_repeat, | 780 | valid_set_repeat=args.valid_set_repeat, |