diff options
author | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
commit | 1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch) | |
tree | 42d3302610046dbc5d39d254f7a2d5d5f601aa18 | |
parent | Fix (diff) | |
download | textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.gz textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.bz2 textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.zip |
Improved class prompt handling
-rw-r--r-- | data/csv.py | 29 | ||||
-rw-r--r-- | environment.yaml | 2 | ||||
-rw-r--r-- | train_dreambooth.py | 30 | ||||
-rw-r--r-- | train_ti.py | 29 | ||||
-rw-r--r-- | training/util.py | 4 |
5 files changed, 31 insertions, 63 deletions
diff --git a/data/csv.py b/data/csv.py index d400757..e25dd3f 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -38,6 +38,7 @@ class CSVDataItem(NamedTuple): | |||
38 | instance_image_path: Path | 38 | instance_image_path: Path |
39 | class_image_path: Path | 39 | class_image_path: Path |
40 | prompt: list[str] | 40 | prompt: list[str] |
41 | cprompt: str | ||
41 | nprompt: str | 42 | nprompt: str |
42 | 43 | ||
43 | 44 | ||
@@ -47,8 +48,6 @@ class CSVDataModule(): | |||
47 | batch_size: int, | 48 | batch_size: int, |
48 | data_file: str, | 49 | data_file: str, |
49 | prompt_processor: PromptProcessor, | 50 | prompt_processor: PromptProcessor, |
50 | instance_identifier: str, | ||
51 | class_identifier: Optional[str] = None, | ||
52 | class_subdir: str = "cls", | 51 | class_subdir: str = "cls", |
53 | num_class_images: int = 100, | 52 | num_class_images: int = 100, |
54 | size: int = 512, | 53 | size: int = 512, |
@@ -77,8 +76,6 @@ class CSVDataModule(): | |||
77 | self.num_class_images = num_class_images | 76 | self.num_class_images = num_class_images |
78 | 77 | ||
79 | self.prompt_processor = prompt_processor | 78 | self.prompt_processor = prompt_processor |
80 | self.instance_identifier = instance_identifier | ||
81 | self.class_identifier = class_identifier | ||
82 | self.size = size | 79 | self.size = size |
83 | self.repeats = repeats | 80 | self.repeats = repeats |
84 | self.dropout = dropout | 81 | self.dropout = dropout |
@@ -96,14 +93,18 @@ class CSVDataModule(): | |||
96 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: | 93 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: |
97 | image = template["image"] if "image" in template else "{}" | 94 | image = template["image"] if "image" in template else "{}" |
98 | prompt = template["prompt"] if "prompt" in template else "{content}" | 95 | prompt = template["prompt"] if "prompt" in template else "{content}" |
96 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | ||
99 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 97 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
100 | 98 | ||
101 | return [ | 99 | return [ |
102 | CSVDataItem( | 100 | CSVDataItem( |
103 | self.data_root.joinpath(image.format(item["image"])), | 101 | self.data_root.joinpath(image.format(item["image"])), |
104 | None, | 102 | None, |
105 | prompt_to_keywords(prompt.format( | 103 | prompt_to_keywords( |
106 | **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), | 104 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
105 | expansions | ||
106 | ), | ||
107 | cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), | ||
107 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 108 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
108 | ) | 109 | ) |
109 | for item in data | 110 | for item in data |
@@ -123,6 +124,7 @@ class CSVDataModule(): | |||
123 | item.instance_image_path, | 124 | item.instance_image_path, |
124 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 125 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), |
125 | item.prompt, | 126 | item.prompt, |
127 | item.cprompt, | ||
126 | item.nprompt, | 128 | item.nprompt, |
127 | ) | 129 | ) |
128 | for item in items | 130 | for item in items |
@@ -160,12 +162,10 @@ class CSVDataModule(): | |||
160 | 162 | ||
161 | def setup(self, stage=None): | 163 | def setup(self, stage=None): |
162 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 164 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |
163 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | ||
164 | num_class_images=self.num_class_images, | 165 | num_class_images=self.num_class_images, |
165 | size=self.size, interpolation=self.interpolation, | 166 | size=self.size, interpolation=self.interpolation, |
166 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) | 167 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) |
167 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 168 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
168 | instance_identifier=self.instance_identifier, | ||
169 | size=self.size, interpolation=self.interpolation, | 169 | size=self.size, interpolation=self.interpolation, |
170 | center_crop=self.center_crop) | 170 | center_crop=self.center_crop) |
171 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 171 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
@@ -187,9 +187,7 @@ class CSVDataset(Dataset): | |||
187 | self, | 187 | self, |
188 | data: List[CSVDataItem], | 188 | data: List[CSVDataItem], |
189 | prompt_processor: PromptProcessor, | 189 | prompt_processor: PromptProcessor, |
190 | instance_identifier: str, | ||
191 | batch_size: int = 1, | 190 | batch_size: int = 1, |
192 | class_identifier: Optional[str] = None, | ||
193 | num_class_images: int = 0, | 191 | num_class_images: int = 0, |
194 | size: int = 512, | 192 | size: int = 512, |
195 | repeats: int = 1, | 193 | repeats: int = 1, |
@@ -201,8 +199,6 @@ class CSVDataset(Dataset): | |||
201 | self.data = data | 199 | self.data = data |
202 | self.prompt_processor = prompt_processor | 200 | self.prompt_processor = prompt_processor |
203 | self.batch_size = batch_size | 201 | self.batch_size = batch_size |
204 | self.instance_identifier = instance_identifier | ||
205 | self.class_identifier = class_identifier | ||
206 | self.num_class_images = num_class_images | 202 | self.num_class_images = num_class_images |
207 | self.dropout = dropout | 203 | self.dropout = dropout |
208 | self.image_cache = {} | 204 | self.image_cache = {} |
@@ -239,14 +235,12 @@ class CSVDataset(Dataset): | |||
239 | 235 | ||
240 | return image | 236 | return image |
241 | 237 | ||
242 | def get_input_ids(self, prompt, identifier): | ||
243 | return self.prompt_processor.get_input_ids(prompt.format(identifier)) | ||
244 | |||
245 | def get_example(self, i): | 238 | def get_example(self, i): |
246 | item = self.data[i % self.num_instance_images] | 239 | item = self.data[i % self.num_instance_images] |
247 | 240 | ||
248 | example = {} | 241 | example = {} |
249 | example["prompts"] = item.prompt | 242 | example["prompts"] = item.prompt |
243 | example["cprompts"] = item.cprompt | ||
250 | example["nprompts"] = item.nprompt | 244 | example["nprompts"] = item.nprompt |
251 | example["instance_images"] = self.get_image(item.instance_image_path) | 245 | example["instance_images"] = self.get_image(item.instance_image_path) |
252 | if self.num_class_images != 0: | 246 | if self.num_class_images != 0: |
@@ -260,13 +254,14 @@ class CSVDataset(Dataset): | |||
260 | example = {} | 254 | example = {} |
261 | 255 | ||
262 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) | 256 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) |
257 | example["cprompts"] = unprocessed_example["cprompts"] | ||
263 | example["nprompts"] = unprocessed_example["nprompts"] | 258 | example["nprompts"] = unprocessed_example["nprompts"] |
264 | 259 | ||
265 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 260 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
266 | example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) | 261 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) |
267 | 262 | ||
268 | if self.num_class_images != 0: | 263 | if self.num_class_images != 0: |
269 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 264 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
270 | example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) | 265 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) |
271 | 266 | ||
272 | return example | 267 | return example |
diff --git a/environment.yaml b/environment.yaml index e598f72..179fa38 100644 --- a/environment.yaml +++ b/environment.yaml | |||
@@ -21,4 +21,4 @@ dependencies: | |||
21 | - test-tube>=0.7.5 | 21 | - test-tube>=0.7.5 |
22 | - transformers==4.25.1 | 22 | - transformers==4.25.1 |
23 | - triton==2.0.0.dev20221202 | 23 | - triton==2.0.0.dev20221202 |
24 | - xformers==0.0.16rc391 | 24 | - xformers==0.0.16rc396 |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 9749c62..ff67d12 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -62,16 +62,10 @@ def parse_args(): | |||
62 | default="template", | 62 | default="template", |
63 | ) | 63 | ) |
64 | parser.add_argument( | 64 | parser.add_argument( |
65 | "--instance_identifier", | 65 | "--project", |
66 | type=str, | 66 | type=str, |
67 | default=None, | 67 | default=None, |
68 | help="A token to use as a placeholder for the concept.", | 68 | help="The name of the current project.", |
69 | ) | ||
70 | parser.add_argument( | ||
71 | "--class_identifier", | ||
72 | type=str, | ||
73 | default=None, | ||
74 | help="A token to use as a placeholder for the concept.", | ||
75 | ) | 69 | ) |
76 | parser.add_argument( | 70 | parser.add_argument( |
77 | "--placeholder_token", | 71 | "--placeholder_token", |
@@ -364,8 +358,8 @@ def parse_args(): | |||
364 | if args.pretrained_model_name_or_path is None: | 358 | if args.pretrained_model_name_or_path is None: |
365 | raise ValueError("You must specify --pretrained_model_name_or_path") | 359 | raise ValueError("You must specify --pretrained_model_name_or_path") |
366 | 360 | ||
367 | if args.instance_identifier is None: | 361 | if args.project is None: |
368 | raise ValueError("You must specify --instance_identifier") | 362 | raise ValueError("You must specify --project") |
369 | 363 | ||
370 | if isinstance(args.initializer_token, str): | 364 | if isinstance(args.initializer_token, str): |
371 | args.initializer_token = [args.initializer_token] | 365 | args.initializer_token = [args.initializer_token] |
@@ -397,7 +391,6 @@ class Checkpointer(CheckpointerBase): | |||
397 | text_encoder, | 391 | text_encoder, |
398 | scheduler, | 392 | scheduler, |
399 | output_dir: Path, | 393 | output_dir: Path, |
400 | instance_identifier, | ||
401 | placeholder_token, | 394 | placeholder_token, |
402 | placeholder_token_id, | 395 | placeholder_token_id, |
403 | sample_image_size, | 396 | sample_image_size, |
@@ -408,7 +401,6 @@ class Checkpointer(CheckpointerBase): | |||
408 | super().__init__( | 401 | super().__init__( |
409 | datamodule=datamodule, | 402 | datamodule=datamodule, |
410 | output_dir=output_dir, | 403 | output_dir=output_dir, |
411 | instance_identifier=instance_identifier, | ||
412 | placeholder_token=placeholder_token, | 404 | placeholder_token=placeholder_token, |
413 | placeholder_token_id=placeholder_token_id, | 405 | placeholder_token_id=placeholder_token_id, |
414 | sample_image_size=sample_image_size, | 406 | sample_image_size=sample_image_size, |
@@ -481,13 +473,8 @@ def main(): | |||
481 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | 473 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." |
482 | ) | 474 | ) |
483 | 475 | ||
484 | instance_identifier = args.instance_identifier | ||
485 | |||
486 | if len(args.placeholder_token) != 0: | ||
487 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) | ||
488 | |||
489 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
490 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 477 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
491 | basepath.mkdir(parents=True, exist_ok=True) | 478 | basepath.mkdir(parents=True, exist_ok=True) |
492 | 479 | ||
493 | accelerator = Accelerator( | 480 | accelerator = Accelerator( |
@@ -637,6 +624,7 @@ def main(): | |||
637 | 624 | ||
638 | def collate_fn(examples): | 625 | def collate_fn(examples): |
639 | prompts = [example["prompts"] for example in examples] | 626 | prompts = [example["prompts"] for example in examples] |
627 | cprompts = [example["cprompts"] for example in examples] | ||
640 | nprompts = [example["nprompts"] for example in examples] | 628 | nprompts = [example["nprompts"] for example in examples] |
641 | input_ids = [example["instance_prompt_ids"] for example in examples] | 629 | input_ids = [example["instance_prompt_ids"] for example in examples] |
642 | pixel_values = [example["instance_images"] for example in examples] | 630 | pixel_values = [example["instance_images"] for example in examples] |
@@ -653,6 +641,7 @@ def main(): | |||
653 | 641 | ||
654 | batch = { | 642 | batch = { |
655 | "prompts": prompts, | 643 | "prompts": prompts, |
644 | "cprompts": cprompts, | ||
656 | "nprompts": nprompts, | 645 | "nprompts": nprompts, |
657 | "input_ids": inputs.input_ids, | 646 | "input_ids": inputs.input_ids, |
658 | "pixel_values": pixel_values, | 647 | "pixel_values": pixel_values, |
@@ -664,8 +653,6 @@ def main(): | |||
664 | data_file=args.train_data_file, | 653 | data_file=args.train_data_file, |
665 | batch_size=args.train_batch_size, | 654 | batch_size=args.train_batch_size, |
666 | prompt_processor=prompt_processor, | 655 | prompt_processor=prompt_processor, |
667 | instance_identifier=instance_identifier, | ||
668 | class_identifier=args.class_identifier, | ||
669 | class_subdir="cls", | 656 | class_subdir="cls", |
670 | num_class_images=args.num_class_images, | 657 | num_class_images=args.num_class_images, |
671 | size=args.resolution, | 658 | size=args.resolution, |
@@ -703,7 +690,7 @@ def main(): | |||
703 | with torch.autocast("cuda"), torch.inference_mode(): | 690 | with torch.autocast("cuda"), torch.inference_mode(): |
704 | for batch in batched_data: | 691 | for batch in batched_data: |
705 | image_name = [item.class_image_path for item in batch] | 692 | image_name = [item.class_image_path for item in batch] |
706 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] | 693 | prompt = [item.cprompt for item in batch] |
707 | nprompt = [item.nprompt for item in batch] | 694 | nprompt = [item.nprompt for item in batch] |
708 | 695 | ||
709 | images = pipeline( | 696 | images = pipeline( |
@@ -812,7 +799,6 @@ def main(): | |||
812 | text_encoder=text_encoder, | 799 | text_encoder=text_encoder, |
813 | scheduler=checkpoint_scheduler, | 800 | scheduler=checkpoint_scheduler, |
814 | output_dir=basepath, | 801 | output_dir=basepath, |
815 | instance_identifier=instance_identifier, | ||
816 | placeholder_token=args.placeholder_token, | 802 | placeholder_token=args.placeholder_token, |
817 | placeholder_token_id=placeholder_token_id, | 803 | placeholder_token_id=placeholder_token_id, |
818 | sample_image_size=args.sample_image_size, | 804 | sample_image_size=args.sample_image_size, |
diff --git a/train_ti.py b/train_ti.py index 9d06c50..55daa35 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -63,16 +63,10 @@ def parse_args(): | |||
63 | default="template", | 63 | default="template", |
64 | ) | 64 | ) |
65 | parser.add_argument( | 65 | parser.add_argument( |
66 | "--instance_identifier", | 66 | "--project", |
67 | type=str, | 67 | type=str, |
68 | default=None, | 68 | default=None, |
69 | help="A token to use as a placeholder for the concept.", | 69 | help="The name of the current project.", |
70 | ) | ||
71 | parser.add_argument( | ||
72 | "--class_identifier", | ||
73 | type=str, | ||
74 | default=None, | ||
75 | help="A token to use as a placeholder for the concept.", | ||
76 | ) | 70 | ) |
77 | parser.add_argument( | 71 | parser.add_argument( |
78 | "--placeholder_token", | 72 | "--placeholder_token", |
@@ -334,6 +328,9 @@ def parse_args(): | |||
334 | if args.pretrained_model_name_or_path is None: | 328 | if args.pretrained_model_name_or_path is None: |
335 | raise ValueError("You must specify --pretrained_model_name_or_path") | 329 | raise ValueError("You must specify --pretrained_model_name_or_path") |
336 | 330 | ||
331 | if args.project is None: | ||
332 | raise ValueError("You must specify --project") | ||
333 | |||
337 | if isinstance(args.initializer_token, str): | 334 | if isinstance(args.initializer_token, str): |
338 | args.initializer_token = [args.initializer_token] | 335 | args.initializer_token = [args.initializer_token] |
339 | 336 | ||
@@ -366,7 +363,6 @@ class Checkpointer(CheckpointerBase): | |||
366 | text_encoder, | 363 | text_encoder, |
367 | scheduler, | 364 | scheduler, |
368 | text_embeddings, | 365 | text_embeddings, |
369 | instance_identifier, | ||
370 | placeholder_token, | 366 | placeholder_token, |
371 | placeholder_token_id, | 367 | placeholder_token_id, |
372 | output_dir: Path, | 368 | output_dir: Path, |
@@ -378,7 +374,6 @@ class Checkpointer(CheckpointerBase): | |||
378 | super().__init__( | 374 | super().__init__( |
379 | datamodule=datamodule, | 375 | datamodule=datamodule, |
380 | output_dir=output_dir, | 376 | output_dir=output_dir, |
381 | instance_identifier=instance_identifier, | ||
382 | placeholder_token=placeholder_token, | 377 | placeholder_token=placeholder_token, |
383 | placeholder_token_id=placeholder_token_id, | 378 | placeholder_token_id=placeholder_token_id, |
384 | sample_image_size=sample_image_size, | 379 | sample_image_size=sample_image_size, |
@@ -441,14 +436,9 @@ class Checkpointer(CheckpointerBase): | |||
441 | def main(): | 436 | def main(): |
442 | args = parse_args() | 437 | args = parse_args() |
443 | 438 | ||
444 | instance_identifier = args.instance_identifier | ||
445 | |||
446 | if len(args.placeholder_token) != 0: | ||
447 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) | ||
448 | |||
449 | global_step_offset = args.global_step | 439 | global_step_offset = args.global_step |
450 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 440 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
451 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 441 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
452 | basepath.mkdir(parents=True, exist_ok=True) | 442 | basepath.mkdir(parents=True, exist_ok=True) |
453 | 443 | ||
454 | accelerator = Accelerator( | 444 | accelerator = Accelerator( |
@@ -567,6 +557,7 @@ def main(): | |||
567 | 557 | ||
568 | def collate_fn(examples): | 558 | def collate_fn(examples): |
569 | prompts = [example["prompts"] for example in examples] | 559 | prompts = [example["prompts"] for example in examples] |
560 | cprompts = [example["cprompts"] for example in examples] | ||
570 | nprompts = [example["nprompts"] for example in examples] | 561 | nprompts = [example["nprompts"] for example in examples] |
571 | input_ids = [example["instance_prompt_ids"] for example in examples] | 562 | input_ids = [example["instance_prompt_ids"] for example in examples] |
572 | pixel_values = [example["instance_images"] for example in examples] | 563 | pixel_values = [example["instance_images"] for example in examples] |
@@ -583,6 +574,7 @@ def main(): | |||
583 | 574 | ||
584 | batch = { | 575 | batch = { |
585 | "prompts": prompts, | 576 | "prompts": prompts, |
577 | "cprompts": cprompts, | ||
586 | "nprompts": nprompts, | 578 | "nprompts": nprompts, |
587 | "input_ids": inputs.input_ids, | 579 | "input_ids": inputs.input_ids, |
588 | "pixel_values": pixel_values, | 580 | "pixel_values": pixel_values, |
@@ -594,8 +586,6 @@ def main(): | |||
594 | data_file=args.train_data_file, | 586 | data_file=args.train_data_file, |
595 | batch_size=args.train_batch_size, | 587 | batch_size=args.train_batch_size, |
596 | prompt_processor=prompt_processor, | 588 | prompt_processor=prompt_processor, |
597 | instance_identifier=args.instance_identifier, | ||
598 | class_identifier=args.class_identifier, | ||
599 | class_subdir="cls", | 589 | class_subdir="cls", |
600 | num_class_images=args.num_class_images, | 590 | num_class_images=args.num_class_images, |
601 | size=args.resolution, | 591 | size=args.resolution, |
@@ -634,7 +624,7 @@ def main(): | |||
634 | with torch.autocast("cuda"), torch.inference_mode(): | 624 | with torch.autocast("cuda"), torch.inference_mode(): |
635 | for batch in batched_data: | 625 | for batch in batched_data: |
636 | image_name = [item.class_image_path for item in batch] | 626 | image_name = [item.class_image_path for item in batch] |
637 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] | 627 | prompt = [item.cprompt for item in batch] |
638 | nprompt = [item.nprompt for item in batch] | 628 | nprompt = [item.nprompt for item in batch] |
639 | 629 | ||
640 | images = pipeline( | 630 | images = pipeline( |
@@ -744,7 +734,6 @@ def main(): | |||
744 | text_encoder=text_encoder, | 734 | text_encoder=text_encoder, |
745 | scheduler=checkpoint_scheduler, | 735 | scheduler=checkpoint_scheduler, |
746 | text_embeddings=text_embeddings, | 736 | text_embeddings=text_embeddings, |
747 | instance_identifier=args.instance_identifier, | ||
748 | placeholder_token=args.placeholder_token, | 737 | placeholder_token=args.placeholder_token, |
749 | placeholder_token_id=placeholder_token_id, | 738 | placeholder_token_id=placeholder_token_id, |
750 | output_dir=basepath, | 739 | output_dir=basepath, |
diff --git a/training/util.py b/training/util.py index e8d22ae..000173d 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -44,7 +44,6 @@ class CheckpointerBase: | |||
44 | self, | 44 | self, |
45 | datamodule, | 45 | datamodule, |
46 | output_dir: Path, | 46 | output_dir: Path, |
47 | instance_identifier, | ||
48 | placeholder_token, | 47 | placeholder_token, |
49 | placeholder_token_id, | 48 | placeholder_token_id, |
50 | sample_image_size, | 49 | sample_image_size, |
@@ -54,7 +53,6 @@ class CheckpointerBase: | |||
54 | ): | 53 | ): |
55 | self.datamodule = datamodule | 54 | self.datamodule = datamodule |
56 | self.output_dir = output_dir | 55 | self.output_dir = output_dir |
57 | self.instance_identifier = instance_identifier | ||
58 | self.placeholder_token = placeholder_token | 56 | self.placeholder_token = placeholder_token |
59 | self.placeholder_token_id = placeholder_token_id | 57 | self.placeholder_token_id = placeholder_token_id |
60 | self.sample_image_size = sample_image_size | 58 | self.sample_image_size = sample_image_size |
@@ -90,7 +88,7 @@ class CheckpointerBase: | |||
90 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | 88 | if j * data.batch_size < self.sample_batch_size * self.sample_batches |
91 | ] | 89 | ] |
92 | prompts = [ | 90 | prompts = [ |
93 | prompt.format(identifier=self.instance_identifier) | 91 | prompt |
94 | for batch in batches | 92 | for batch in batches |
95 | for prompt in batch["prompts"] | 93 | for prompt in batch["prompts"] |
96 | ] | 94 | ] |