author     Volpeon <git@volpeon.ink>  2022-12-23 21:47:12 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-23 21:47:12 +0100
commit     1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch)
tree       42d3302610046dbc5d39d254f7a2d5d5f601aa18
parent     Fix (diff)
Improved class prompt handling
-rw-r--r--  data/csv.py          29
-rw-r--r--  environment.yaml      2
-rw-r--r--  train_dreambooth.py  30
-rw-r--r--  train_ti.py          29
-rw-r--r--  training/util.py      4
5 files changed, 31 insertions(+), 63 deletions(-)
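
For orientation, a rough sketch of how the new class prompt handling resolves prompts in CSVDataModule.prepare_items: each data item now carries an explicit class prompt (cprompt) taken from its own template entry, instead of reusing the instance prompt with a class_identifier substituted in. prepare_prompt and prompt_to_keywords below are simplified stand-ins (their real definitions live elsewhere in data/csv.py and are untouched by this commit), and the template/item values are invented for illustration only.

def prepare_prompt(prompt: str) -> dict:
    # Simplified stand-in: build the mapping consumed by str.format() below.
    return {"content": prompt}


def prompt_to_keywords(prompt: str, expansions: dict) -> list[str]:
    # Simplified stand-in: the real helper also applies keyword expansions.
    return [keyword.strip() for keyword in prompt.split(",") if keyword.strip()]


# Invented template/item values, mirroring the keys read by prepare_items().
template = {
    "prompt": "a photo of {content}",
    "cprompt": "a photo of a dog",  # new: explicit class prompt template
    "nprompt": "lowres, blurry",
}
item = {"image": "dog-01", "prompt": "<token> dog, sitting"}

prompt = prompt_to_keywords(
    template["prompt"].format(**prepare_prompt(item.get("prompt", ""))),
    expansions={},
)
cprompt = template["cprompt"].format(**prepare_prompt(item.get("cprompt", "")))
nprompt = template["nprompt"].format(**prepare_prompt(item.get("nprompt", "")))

print(prompt)   # ['a photo of <token> dog', 'sitting']
print(cprompt)  # a photo of a dog
print(nprompt)  # lowres, blurry

The resolved cprompt travels with each CSVDataItem, so class image generation and class_prompt_ids no longer depend on a class_identifier placeholder.
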
diff --git a/data/csv.py b/data/csv.py
index d400757..e25dd3f 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -38,6 +38,7 @@ class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
     prompt: list[str]
+    cprompt: str
     nprompt: str
 
 
@@ -47,8 +48,6 @@ class CSVDataModule():
         batch_size: int,
         data_file: str,
         prompt_processor: PromptProcessor,
-        instance_identifier: str,
-        class_identifier: Optional[str] = None,
         class_subdir: str = "cls",
         num_class_images: int = 100,
         size: int = 512,
@@ -77,8 +76,6 @@ class CSVDataModule():
         self.num_class_images = num_class_images
 
         self.prompt_processor = prompt_processor
-        self.instance_identifier = instance_identifier
-        self.class_identifier = class_identifier
         self.size = size
         self.repeats = repeats
         self.dropout = dropout
@@ -96,14 +93,18 @@ class CSVDataModule():
     def prepare_items(self, template, expansions, data) -> list[CSVDataItem]:
         image = template["image"] if "image" in template else "{}"
         prompt = template["prompt"] if "prompt" in template else "{content}"
+        cprompt = template["cprompt"] if "cprompt" in template else "{content}"
         nprompt = template["nprompt"] if "nprompt" in template else "{content}"
 
         return [
             CSVDataItem(
                 self.data_root.joinpath(image.format(item["image"])),
                 None,
-                prompt_to_keywords(prompt.format(
-                    **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions),
+                prompt_to_keywords(
+                    prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
+                    expansions
+                ),
+                cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
                 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
             )
             for item in data
@@ -123,6 +124,7 @@ class CSVDataModule():
                 item.instance_image_path,
                 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
                 item.prompt,
+                item.cprompt,
                 item.nprompt,
             )
             for item in items
@@ -160,12 +162,10 @@ class CSVDataModule():
 
     def setup(self, stage=None):
         train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
-                                   instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout)
         val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
-                                 instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
                                  center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
@@ -187,9 +187,7 @@ class CSVDataset(Dataset):
         self,
         data: List[CSVDataItem],
         prompt_processor: PromptProcessor,
-        instance_identifier: str,
         batch_size: int = 1,
-        class_identifier: Optional[str] = None,
         num_class_images: int = 0,
         size: int = 512,
         repeats: int = 1,
@@ -201,8 +199,6 @@ class CSVDataset(Dataset):
         self.data = data
         self.prompt_processor = prompt_processor
         self.batch_size = batch_size
-        self.instance_identifier = instance_identifier
-        self.class_identifier = class_identifier
         self.num_class_images = num_class_images
         self.dropout = dropout
         self.image_cache = {}
@@ -239,14 +235,12 @@ class CSVDataset(Dataset):
 
         return image
 
-    def get_input_ids(self, prompt, identifier):
-        return self.prompt_processor.get_input_ids(prompt.format(identifier))
-
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
 
         example = {}
         example["prompts"] = item.prompt
+        example["cprompts"] = item.cprompt
         example["nprompts"] = item.nprompt
         example["instance_images"] = self.get_image(item.instance_image_path)
         if self.num_class_images != 0:
@@ -260,13 +254,14 @@ class CSVDataset(Dataset):
         example = {}
 
         example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout)
+        example["cprompts"] = unprocessed_example["cprompts"]
         example["nprompts"] = unprocessed_example["nprompts"]
 
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
-        example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier)
+        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"])
 
         if self.num_class_images != 0:
             example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
-            example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)
+            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
 
         return example
diff --git a/environment.yaml b/environment.yaml
index e598f72..179fa38 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -21,4 +21,4 @@ dependencies:
   - test-tube>=0.7.5
   - transformers==4.25.1
   - triton==2.0.0.dev20221202
-  - xformers==0.0.16rc391
+  - xformers==0.0.16rc396
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9749c62..ff67d12 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -62,16 +62,10 @@ def parse_args():
         default="template",
     )
     parser.add_argument(
-        "--instance_identifier",
+        "--project",
         type=str,
         default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--class_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
+        help="The name of the current project.",
     )
     parser.add_argument(
         "--placeholder_token",
@@ -364,8 +358,8 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.instance_identifier is None:
-        raise ValueError("You must specify --instance_identifier")
+    if args.project is None:
+        raise ValueError("You must specify --project")
 
     if isinstance(args.initializer_token, str):
         args.initializer_token = [args.initializer_token]
@@ -397,7 +391,6 @@ class Checkpointer(CheckpointerBase):
         text_encoder,
         scheduler,
         output_dir: Path,
-        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         sample_image_size,
@@ -408,7 +401,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            instance_identifier=instance_identifier,
             placeholder_token=placeholder_token,
             placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
@@ -481,13 +473,8 @@ def main():
481 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 473 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
482 ) 474 )
483 475
484 instance_identifier = args.instance_identifier
485
486 if len(args.placeholder_token) != 0:
487 instance_identifier = instance_identifier.format(args.placeholder_token[0])
488
489 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 476 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
490 basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) 477 basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
491 basepath.mkdir(parents=True, exist_ok=True) 478 basepath.mkdir(parents=True, exist_ok=True)
492 479
493 accelerator = Accelerator( 480 accelerator = Accelerator(
@@ -637,6 +624,7 @@ def main():
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
+        cprompts = [example["cprompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
@@ -653,6 +641,7 @@ def main():
 
         batch = {
             "prompts": prompts,
+            "cprompts": cprompts,
             "nprompts": nprompts,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
@@ -664,8 +653,6 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
-        instance_identifier=instance_identifier,
-        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -703,7 +690,7 @@ def main():
         with torch.autocast("cuda"), torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                prompt = [item.cprompt for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
@@ -812,7 +799,6 @@ def main():
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
         output_dir=basepath,
-        instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         sample_image_size=args.sample_image_size,
diff --git a/train_ti.py b/train_ti.py
index 9d06c50..55daa35 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -63,16 +63,10 @@ def parse_args():
         default="template",
     )
     parser.add_argument(
-        "--instance_identifier",
+        "--project",
         type=str,
         default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--class_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
+        help="The name of the current project.",
     )
     parser.add_argument(
         "--placeholder_token",
@@ -334,6 +328,9 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
+    if args.project is None:
+        raise ValueError("You must specify --project")
+
     if isinstance(args.initializer_token, str):
         args.initializer_token = [args.initializer_token]
 
@@ -366,7 +363,6 @@ class Checkpointer(CheckpointerBase):
         text_encoder,
         scheduler,
         text_embeddings,
-        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -378,7 +374,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            instance_identifier=instance_identifier,
             placeholder_token=placeholder_token,
             placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
@@ -441,14 +436,9 @@ class Checkpointer(CheckpointerBase):
 def main():
     args = parse_args()
 
-    instance_identifier = args.instance_identifier
-
-    if len(args.placeholder_token) != 0:
-        instance_identifier = instance_identifier.format(args.placeholder_token[0])
-
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -567,6 +557,7 @@ def main():
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
+        cprompts = [example["cprompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
@@ -583,6 +574,7 @@ def main():
 
         batch = {
             "prompts": prompts,
+            "cprompts": cprompts,
             "nprompts": nprompts,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
@@ -594,8 +586,6 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
-        instance_identifier=args.instance_identifier,
-        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -634,7 +624,7 @@ def main():
         with torch.autocast("cuda"), torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                prompt = [item.cprompt for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
@@ -744,7 +734,6 @@ def main():
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
         text_embeddings=text_embeddings,
-        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
diff --git a/training/util.py b/training/util.py
index e8d22ae..000173d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,7 +44,6 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         sample_image_size,
@@ -54,7 +53,6 @@ class CheckpointerBase:
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
@@ -90,7 +88,7 @@ class CheckpointerBase:
                 if j * data.batch_size < self.sample_batch_size * self.sample_batches
             ]
             prompts = [
-                prompt.format(identifier=self.instance_identifier)
+                prompt
                 for batch in batches
                 for prompt in batch["prompts"]
             ]