From dfc51d6d74410acefab86d2938a2b864be603668 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 09:35:57 +0100 Subject: Update --- data/csv.py | 16 ++++++++-------- train_dreambooth.py | 10 +++++++++- train_ti.py | 10 +++++++++- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/data/csv.py b/data/csv.py index b45ac77..0810c2c 100644 --- a/data/csv.py +++ b/data/csv.py @@ -39,8 +39,8 @@ class CSVDataItem(NamedTuple): instance_image_path: Path class_image_path: Path prompt: list[str] - cprompt: list[str] - nprompt: list[str] + cprompt: str + nprompt: str class CSVDataModule(): @@ -105,14 +105,14 @@ class CSVDataModule(): prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions ), - prompt_to_keywords( + keywords_to_prompt(prompt_to_keywords( cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions - ), - prompt_to_keywords( + )), + keywords_to_prompt(prompt_to_keywords( nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), expansions - ), + )), ) for item in data ] @@ -261,8 +261,8 @@ class CSVDataset(Dataset): example = {} example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) - example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"]) - example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"]) + example["cprompts"] = unprocessed_example["cprompts"] + example["nprompts"] = unprocessed_example["nprompts"] example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) diff --git a/train_dreambooth.py b/train_dreambooth.py index 1a79b2b..c7899a0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -104,6 +104,12 @@ def parse_args(): default=1, help="How many class images to generate." ) + parser.add_argument( + "--class_image_dir", + type=str, + default="cls", + help="The directory where class images will be saved.", + ) parser.add_argument( "--repeats", type=int, @@ -653,7 +659,7 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, - class_subdir="cls", + class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, @@ -696,6 +702,8 @@ def main(): images = pipeline( prompt=prompt, negative_prompt=nprompt, + height=args.sample_image_size, + width=args.sample_image_size, num_inference_steps=args.sample_steps ).images diff --git a/train_ti.py b/train_ti.py index cc208f0..52bd675 100644 --- a/train_ti.py +++ b/train_ti.py @@ -86,6 +86,12 @@ def parse_args(): default=1, help="How many class images to generate." ) + parser.add_argument( + "--class_image_dir", + type=str, + default="cls", + help="The directory where class images will be saved.", + ) parser.add_argument( "--repeats", type=int, @@ -586,7 +592,7 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, - class_subdir="cls", + class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, @@ -630,6 +636,8 @@ def main(): images = pipeline( prompt=prompt, negative_prompt=nprompt, + height=args.sample_image_size, + width=args.sample_image_size, num_inference_steps=args.sample_steps ).images -- cgit v1.2.3-54-g00ecf