diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/dreambooth.py b/dreambooth.py index 31416e9..5521b21 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -57,6 +57,11 @@ def parse_args(): | |||
57 | help="A folder containing the training data." | 57 | help="A folder containing the training data." |
58 | ) | 58 | ) |
59 | parser.add_argument( | 59 | parser.add_argument( |
60 | "--train_data_template", | ||
61 | type=str, | ||
62 | default="template", | ||
63 | ) | ||
64 | parser.add_argument( | ||
60 | "--instance_identifier", | 65 | "--instance_identifier", |
61 | type=str, | 66 | type=str, |
62 | default=None, | 67 | default=None, |
@@ -768,6 +773,7 @@ def main(): | |||
768 | repeats=args.repeats, | 773 | repeats=args.repeats, |
769 | dropout=args.tag_dropout, | 774 | dropout=args.tag_dropout, |
770 | center_crop=args.center_crop, | 775 | center_crop=args.center_crop, |
776 | template_key=args.train_data_template, | ||
771 | valid_set_size=args.valid_set_size, | 777 | valid_set_size=args.valid_set_size, |
772 | num_workers=args.dataloader_num_workers, | 778 | num_workers=args.dataloader_num_workers, |
773 | collate_fn=collate_fn | 779 | collate_fn=collate_fn |