diff options
author | Volpeon <git@volpeon.ink> | 2022-12-24 09:35:57 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-24 09:35:57 +0100 |
commit | dfc51d6d74410acefab86d2938a2b864be603668 (patch) | |
tree | bed57fcd481bc243324950f93e890294703533f6 /train_dreambooth.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.tar.gz textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.tar.bz2 textual-inversion-diff-dfc51d6d74410acefab86d2938a2b864be603668.zip |
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1a79b2b..c7899a0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -105,6 +105,12 @@ def parse_args(): | |||
105 | help="How many class images to generate." | 105 | help="How many class images to generate." |
106 | ) | 106 | ) |
107 | parser.add_argument( | 107 | parser.add_argument( |
108 | "--class_image_dir", | ||
109 | type=str, | ||
110 | default="cls", | ||
111 | help="The directory where class images will be saved.", | ||
112 | ) | ||
113 | parser.add_argument( | ||
108 | "--repeats", | 114 | "--repeats", |
109 | type=int, | 115 | type=int, |
110 | default=1, | 116 | default=1, |
@@ -653,7 +659,7 @@ def main(): | |||
653 | data_file=args.train_data_file, | 659 | data_file=args.train_data_file, |
654 | batch_size=args.train_batch_size, | 660 | batch_size=args.train_batch_size, |
655 | prompt_processor=prompt_processor, | 661 | prompt_processor=prompt_processor, |
656 | class_subdir="cls", | 662 | class_subdir=args.class_image_dir, |
657 | num_class_images=args.num_class_images, | 663 | num_class_images=args.num_class_images, |
658 | size=args.resolution, | 664 | size=args.resolution, |
659 | repeats=args.repeats, | 665 | repeats=args.repeats, |
@@ -696,6 +702,8 @@ def main(): | |||
696 | images = pipeline( | 702 | images = pipeline( |
697 | prompt=prompt, | 703 | prompt=prompt, |
698 | negative_prompt=nprompt, | 704 | negative_prompt=nprompt, |
705 | height=args.sample_image_size, | ||
706 | width=args.sample_image_size, | ||
699 | num_inference_steps=args.sample_steps | 707 | num_inference_steps=args.sample_steps |
700 | ).images | 708 | ).images |
701 | 709 | ||