summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-13 20:49:57 +0100
committerVolpeon <git@volpeon.ink>2022-12-13 20:49:57 +0100
commitb73469706091c8aaf3f028de96ab017f5a845639 (patch)
tree892208ff6c19a11b9870e0ba298d88fb0d4bd5ba /dreambooth.py
parentFixed sample/checkpoint frequency (diff)
downloadtextual-inversion-diff-b73469706091c8aaf3f028de96ab017f5a845639.tar.gz
textual-inversion-diff-b73469706091c8aaf3f028de96ab017f5a845639.tar.bz2
textual-inversion-diff-b73469706091c8aaf3f028de96ab017f5a845639.zip
Optimized Textual Inversion training by filtering dataset by existence of added tokens
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 31416e9..5521b21 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -57,6 +57,11 @@ def parse_args():
57 help="A folder containing the training data." 57 help="A folder containing the training data."
58 ) 58 )
59 parser.add_argument( 59 parser.add_argument(
60 "--train_data_template",
61 type=str,
62 default="template",
63 )
64 parser.add_argument(
60 "--instance_identifier", 65 "--instance_identifier",
61 type=str, 66 type=str,
62 default=None, 67 default=None,
@@ -768,6 +773,7 @@ def main():
768 repeats=args.repeats, 773 repeats=args.repeats,
769 dropout=args.tag_dropout, 774 dropout=args.tag_dropout,
770 center_crop=args.center_crop, 775 center_crop=args.center_crop,
776 template_key=args.train_data_template,
771 valid_set_size=args.valid_set_size, 777 valid_set_size=args.valid_set_size,
772 num_workers=args.dataloader_num_workers, 778 num_workers=args.dataloader_num_workers,
773 collate_fn=collate_fn 779 collate_fn=collate_fn