summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-24 09:35:57 +0100
committerVolpeon <git@volpeon.ink>2022-12-24 09:35:57 +0100
commitdfc51d6d74410acefab86d2938a2b864be603668 (patch)
treebed57fcd481bc243324950f93e890294703533f6 /train_dreambooth.py
parentFix (diff)
downloadtextual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.tar.gz
textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.tar.bz2
textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.zip
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1a79b2b..c7899a0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -105,6 +105,12 @@ def parse_args():
105 help="How many class images to generate." 105 help="How many class images to generate."
106 ) 106 )
107 parser.add_argument( 107 parser.add_argument(
108 "--class_image_dir",
109 type=str,
110 default="cls",
111 help="The directory where class images will be saved.",
112 )
113 parser.add_argument(
108 "--repeats", 114 "--repeats",
109 type=int, 115 type=int,
110 default=1, 116 default=1,
@@ -653,7 +659,7 @@ def main():
653 data_file=args.train_data_file, 659 data_file=args.train_data_file,
654 batch_size=args.train_batch_size, 660 batch_size=args.train_batch_size,
655 prompt_processor=prompt_processor, 661 prompt_processor=prompt_processor,
656 class_subdir="cls", 662 class_subdir=args.class_image_dir,
657 num_class_images=args.num_class_images, 663 num_class_images=args.num_class_images,
658 size=args.resolution, 664 size=args.resolution,
659 repeats=args.repeats, 665 repeats=args.repeats,
@@ -696,6 +702,8 @@ def main():
696 images = pipeline( 702 images = pipeline(
697 prompt=prompt, 703 prompt=prompt,
698 negative_prompt=nprompt, 704 negative_prompt=nprompt,
705 height=args.sample_image_size,
706 width=args.sample_image_size,
699 num_inference_steps=args.sample_steps 707 num_inference_steps=args.sample_steps
700 ).images 708 ).images
701 709