diff options
| -rw-r--r-- | data/csv.py | 16 | ||||
| -rw-r--r-- | train_dreambooth.py | 10 | ||||
| -rw-r--r-- | train_ti.py | 10 |
3 files changed, 26 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py index b45ac77..0810c2c 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -39,8 +39,8 @@ class CSVDataItem(NamedTuple): | |||
| 39 | instance_image_path: Path | 39 | instance_image_path: Path |
| 40 | class_image_path: Path | 40 | class_image_path: Path |
| 41 | prompt: list[str] | 41 | prompt: list[str] |
| 42 | cprompt: list[str] | 42 | cprompt: str |
| 43 | nprompt: list[str] | 43 | nprompt: str |
| 44 | 44 | ||
| 45 | 45 | ||
| 46 | class CSVDataModule(): | 46 | class CSVDataModule(): |
| @@ -105,14 +105,14 @@ class CSVDataModule(): | |||
| 105 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 105 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 106 | expansions | 106 | expansions |
| 107 | ), | 107 | ), |
| 108 | prompt_to_keywords( | 108 | keywords_to_prompt(prompt_to_keywords( |
| 109 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 109 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 110 | expansions | 110 | expansions |
| 111 | ), | 111 | )), |
| 112 | prompt_to_keywords( | 112 | keywords_to_prompt(prompt_to_keywords( |
| 113 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 113 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 114 | expansions | 114 | expansions |
| 115 | ), | 115 | )), |
| 116 | ) | 116 | ) |
| 117 | for item in data | 117 | for item in data |
| 118 | ] | 118 | ] |
| @@ -261,8 +261,8 @@ class CSVDataset(Dataset): | |||
| 261 | example = {} | 261 | example = {} |
| 262 | 262 | ||
| 263 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) | 263 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) |
| 264 | example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"]) | 264 | example["cprompts"] = unprocessed_example["cprompts"] |
| 265 | example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"]) | 265 | example["nprompts"] = unprocessed_example["nprompts"] |
| 266 | 266 | ||
| 267 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 267 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 268 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) | 268 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1a79b2b..c7899a0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -105,6 +105,12 @@ def parse_args(): | |||
| 105 | help="How many class images to generate." | 105 | help="How many class images to generate." |
| 106 | ) | 106 | ) |
| 107 | parser.add_argument( | 107 | parser.add_argument( |
| 108 | "--class_image_dir", | ||
| 109 | type=str, | ||
| 110 | default="cls", | ||
| 111 | help="The directory where class images will be saved.", | ||
| 112 | ) | ||
| 113 | parser.add_argument( | ||
| 108 | "--repeats", | 114 | "--repeats", |
| 109 | type=int, | 115 | type=int, |
| 110 | default=1, | 116 | default=1, |
| @@ -653,7 +659,7 @@ def main(): | |||
| 653 | data_file=args.train_data_file, | 659 | data_file=args.train_data_file, |
| 654 | batch_size=args.train_batch_size, | 660 | batch_size=args.train_batch_size, |
| 655 | prompt_processor=prompt_processor, | 661 | prompt_processor=prompt_processor, |
| 656 | class_subdir="cls", | 662 | class_subdir=args.class_image_dir, |
| 657 | num_class_images=args.num_class_images, | 663 | num_class_images=args.num_class_images, |
| 658 | size=args.resolution, | 664 | size=args.resolution, |
| 659 | repeats=args.repeats, | 665 | repeats=args.repeats, |
| @@ -696,6 +702,8 @@ def main(): | |||
| 696 | images = pipeline( | 702 | images = pipeline( |
| 697 | prompt=prompt, | 703 | prompt=prompt, |
| 698 | negative_prompt=nprompt, | 704 | negative_prompt=nprompt, |
| 705 | height=args.sample_image_size, | ||
| 706 | width=args.sample_image_size, | ||
| 699 | num_inference_steps=args.sample_steps | 707 | num_inference_steps=args.sample_steps |
| 700 | ).images | 708 | ).images |
| 701 | 709 | ||
diff --git a/train_ti.py b/train_ti.py index cc208f0..52bd675 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -87,6 +87,12 @@ def parse_args(): | |||
| 87 | help="How many class images to generate." | 87 | help="How many class images to generate." |
| 88 | ) | 88 | ) |
| 89 | parser.add_argument( | 89 | parser.add_argument( |
| 90 | "--class_image_dir", | ||
| 91 | type=str, | ||
| 92 | default="cls", | ||
| 93 | help="The directory where class images will be saved.", | ||
| 94 | ) | ||
| 95 | parser.add_argument( | ||
| 90 | "--repeats", | 96 | "--repeats", |
| 91 | type=int, | 97 | type=int, |
| 92 | default=1, | 98 | default=1, |
| @@ -586,7 +592,7 @@ def main(): | |||
| 586 | data_file=args.train_data_file, | 592 | data_file=args.train_data_file, |
| 587 | batch_size=args.train_batch_size, | 593 | batch_size=args.train_batch_size, |
| 588 | prompt_processor=prompt_processor, | 594 | prompt_processor=prompt_processor, |
| 589 | class_subdir="cls", | 595 | class_subdir=args.class_image_dir, |
| 590 | num_class_images=args.num_class_images, | 596 | num_class_images=args.num_class_images, |
| 591 | size=args.resolution, | 597 | size=args.resolution, |
| 592 | repeats=args.repeats, | 598 | repeats=args.repeats, |
| @@ -630,6 +636,8 @@ def main(): | |||
| 630 | images = pipeline( | 636 | images = pipeline( |
| 631 | prompt=prompt, | 637 | prompt=prompt, |
| 632 | negative_prompt=nprompt, | 638 | negative_prompt=nprompt, |
| 639 | height=args.sample_image_size, | ||
| 640 | width=args.sample_image_size, | ||
| 633 | num_inference_steps=args.sample_steps | 641 | num_inference_steps=args.sample_steps |
| 634 | ).images | 642 | ).images |
| 635 | 643 | ||
