diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1a79b2b..c7899a0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -105,6 +105,12 @@ def parse_args(): | |||
| 105 | help="How many class images to generate." | 105 | help="How many class images to generate." |
| 106 | ) | 106 | ) |
| 107 | parser.add_argument( | 107 | parser.add_argument( |
| 108 | "--class_image_dir", | ||
| 109 | type=str, | ||
| 110 | default="cls", | ||
| 111 | help="The directory where class images will be saved.", | ||
| 112 | ) | ||
| 113 | parser.add_argument( | ||
| 108 | "--repeats", | 114 | "--repeats", |
| 109 | type=int, | 115 | type=int, |
| 110 | default=1, | 116 | default=1, |
| @@ -653,7 +659,7 @@ def main(): | |||
| 653 | data_file=args.train_data_file, | 659 | data_file=args.train_data_file, |
| 654 | batch_size=args.train_batch_size, | 660 | batch_size=args.train_batch_size, |
| 655 | prompt_processor=prompt_processor, | 661 | prompt_processor=prompt_processor, |
| 656 | class_subdir="cls", | 662 | class_subdir=args.class_image_dir, |
| 657 | num_class_images=args.num_class_images, | 663 | num_class_images=args.num_class_images, |
| 658 | size=args.resolution, | 664 | size=args.resolution, |
| 659 | repeats=args.repeats, | 665 | repeats=args.repeats, |
| @@ -696,6 +702,8 @@ def main(): | |||
| 696 | images = pipeline( | 702 | images = pipeline( |
| 697 | prompt=prompt, | 703 | prompt=prompt, |
| 698 | negative_prompt=nprompt, | 704 | negative_prompt=nprompt, |
| 705 | height=args.sample_image_size, | ||
| 706 | width=args.sample_image_size, | ||
| 699 | num_inference_steps=args.sample_steps | 707 | num_inference_steps=args.sample_steps |
| 700 | ).images | 708 | ).images |
| 701 | 709 | ||
