summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py16
-rw-r--r--train_dreambooth.py10
-rw-r--r--train_ti.py10
3 files changed, 26 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py
index b45ac77..0810c2c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -39,8 +39,8 @@ class CSVDataItem(NamedTuple):
39 instance_image_path: Path 39 instance_image_path: Path
40 class_image_path: Path 40 class_image_path: Path
41 prompt: list[str] 41 prompt: list[str]
42 cprompt: list[str] 42 cprompt: str
43 nprompt: list[str] 43 nprompt: str
44 44
45 45
46class CSVDataModule(): 46class CSVDataModule():
@@ -105,14 +105,14 @@ class CSVDataModule():
105 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 105 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
106 expansions 106 expansions
107 ), 107 ),
108 prompt_to_keywords( 108 keywords_to_prompt(prompt_to_keywords(
109 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 109 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
110 expansions 110 expansions
111 ), 111 )),
112 prompt_to_keywords( 112 keywords_to_prompt(prompt_to_keywords(
113 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 113 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
114 expansions 114 expansions
115 ), 115 )),
116 ) 116 )
117 for item in data 117 for item in data
118 ] 118 ]
@@ -261,8 +261,8 @@ class CSVDataset(Dataset):
261 example = {} 261 example = {}
262 262
263 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) 263 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
264 example["cprompts"] = keywords_to_prompt(unprocessed_example["cprompts"]) 264 example["cprompts"] = unprocessed_example["cprompts"]
265 example["nprompts"] = keywords_to_prompt(unprocessed_example["nprompts"]) 265 example["nprompts"] = unprocessed_example["nprompts"]
266 266
267 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 267 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
268 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) 268 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"])
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1a79b2b..c7899a0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -105,6 +105,12 @@ def parse_args():
105 help="How many class images to generate." 105 help="How many class images to generate."
106 ) 106 )
107 parser.add_argument( 107 parser.add_argument(
108 "--class_image_dir",
109 type=str,
110 default="cls",
111 help="The directory where class images will be saved.",
112 )
113 parser.add_argument(
108 "--repeats", 114 "--repeats",
109 type=int, 115 type=int,
110 default=1, 116 default=1,
@@ -653,7 +659,7 @@ def main():
653 data_file=args.train_data_file, 659 data_file=args.train_data_file,
654 batch_size=args.train_batch_size, 660 batch_size=args.train_batch_size,
655 prompt_processor=prompt_processor, 661 prompt_processor=prompt_processor,
656 class_subdir="cls", 662 class_subdir=args.class_image_dir,
657 num_class_images=args.num_class_images, 663 num_class_images=args.num_class_images,
658 size=args.resolution, 664 size=args.resolution,
659 repeats=args.repeats, 665 repeats=args.repeats,
@@ -696,6 +702,8 @@ def main():
696 images = pipeline( 702 images = pipeline(
697 prompt=prompt, 703 prompt=prompt,
698 negative_prompt=nprompt, 704 negative_prompt=nprompt,
705 height=args.sample_image_size,
706 width=args.sample_image_size,
699 num_inference_steps=args.sample_steps 707 num_inference_steps=args.sample_steps
700 ).images 708 ).images
701 709
diff --git a/train_ti.py b/train_ti.py
index cc208f0..52bd675 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -87,6 +87,12 @@ def parse_args():
87 help="How many class images to generate." 87 help="How many class images to generate."
88 ) 88 )
89 parser.add_argument( 89 parser.add_argument(
90 "--class_image_dir",
91 type=str,
92 default="cls",
93 help="The directory where class images will be saved.",
94 )
95 parser.add_argument(
90 "--repeats", 96 "--repeats",
91 type=int, 97 type=int,
92 default=1, 98 default=1,
@@ -586,7 +592,7 @@ def main():
586 data_file=args.train_data_file, 592 data_file=args.train_data_file,
587 batch_size=args.train_batch_size, 593 batch_size=args.train_batch_size,
588 prompt_processor=prompt_processor, 594 prompt_processor=prompt_processor,
589 class_subdir="cls", 595 class_subdir=args.class_image_dir,
590 num_class_images=args.num_class_images, 596 num_class_images=args.num_class_images,
591 size=args.resolution, 597 size=args.resolution,
592 repeats=args.repeats, 598 repeats=args.repeats,
@@ -630,6 +636,8 @@ def main():
630 images = pipeline( 636 images = pipeline(
631 prompt=prompt, 637 prompt=prompt,
632 negative_prompt=nprompt, 638 negative_prompt=nprompt,
639 height=args.sample_image_size,
640 width=args.sample_image_size,
633 num_inference_steps=args.sample_steps 641 num_inference_steps=args.sample_steps
634 ).images 642 ).images
635 643