From dfc51d6d74410acefab86d2938a2b864be603668 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 09:35:57 +0100 Subject: Update --- train_ti.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index cc208f0..52bd675 100644 --- a/train_ti.py +++ b/train_ti.py @@ -86,6 +86,12 @@ def parse_args(): default=1, help="How many class images to generate." ) + parser.add_argument( + "--class_image_dir", + type=str, + default="cls", + help="The directory where class images will be saved.", + ) parser.add_argument( "--repeats", type=int, @@ -586,7 +592,7 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, - class_subdir="cls", + class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, @@ -630,6 +636,8 @@ def main(): images = pipeline( prompt=prompt, negative_prompt=nprompt, + height=args.sample_image_size, + width=args.sample_image_size, num_inference_steps=args.sample_steps ).images -- cgit v1.2.3-54-g00ecf