diff options
| -rw-r--r-- | data/csv.py | 29 | ||||
| -rw-r--r-- | environment.yaml | 2 | ||||
| -rw-r--r-- | train_dreambooth.py | 30 | ||||
| -rw-r--r-- | train_ti.py | 29 | ||||
| -rw-r--r-- | training/util.py | 4 |
5 files changed, 31 insertions, 63 deletions
diff --git a/data/csv.py b/data/csv.py index d400757..e25dd3f 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -38,6 +38,7 @@ class CSVDataItem(NamedTuple): | |||
| 38 | instance_image_path: Path | 38 | instance_image_path: Path |
| 39 | class_image_path: Path | 39 | class_image_path: Path |
| 40 | prompt: list[str] | 40 | prompt: list[str] |
| 41 | cprompt: str | ||
| 41 | nprompt: str | 42 | nprompt: str |
| 42 | 43 | ||
| 43 | 44 | ||
| @@ -47,8 +48,6 @@ class CSVDataModule(): | |||
| 47 | batch_size: int, | 48 | batch_size: int, |
| 48 | data_file: str, | 49 | data_file: str, |
| 49 | prompt_processor: PromptProcessor, | 50 | prompt_processor: PromptProcessor, |
| 50 | instance_identifier: str, | ||
| 51 | class_identifier: Optional[str] = None, | ||
| 52 | class_subdir: str = "cls", | 51 | class_subdir: str = "cls", |
| 53 | num_class_images: int = 100, | 52 | num_class_images: int = 100, |
| 54 | size: int = 512, | 53 | size: int = 512, |
| @@ -77,8 +76,6 @@ class CSVDataModule(): | |||
| 77 | self.num_class_images = num_class_images | 76 | self.num_class_images = num_class_images |
| 78 | 77 | ||
| 79 | self.prompt_processor = prompt_processor | 78 | self.prompt_processor = prompt_processor |
| 80 | self.instance_identifier = instance_identifier | ||
| 81 | self.class_identifier = class_identifier | ||
| 82 | self.size = size | 79 | self.size = size |
| 83 | self.repeats = repeats | 80 | self.repeats = repeats |
| 84 | self.dropout = dropout | 81 | self.dropout = dropout |
| @@ -96,14 +93,18 @@ class CSVDataModule(): | |||
| 96 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: | 93 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: |
| 97 | image = template["image"] if "image" in template else "{}" | 94 | image = template["image"] if "image" in template else "{}" |
| 98 | prompt = template["prompt"] if "prompt" in template else "{content}" | 95 | prompt = template["prompt"] if "prompt" in template else "{content}" |
| 96 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | ||
| 99 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 97 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
| 100 | 98 | ||
| 101 | return [ | 99 | return [ |
| 102 | CSVDataItem( | 100 | CSVDataItem( |
| 103 | self.data_root.joinpath(image.format(item["image"])), | 101 | self.data_root.joinpath(image.format(item["image"])), |
| 104 | None, | 102 | None, |
| 105 | prompt_to_keywords(prompt.format( | 103 | prompt_to_keywords( |
| 106 | **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), | 104 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 105 | expansions | ||
| 106 | ), | ||
| 107 | cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), | ||
| 107 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 108 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 108 | ) | 109 | ) |
| 109 | for item in data | 110 | for item in data |
| @@ -123,6 +124,7 @@ class CSVDataModule(): | |||
| 123 | item.instance_image_path, | 124 | item.instance_image_path, |
| 124 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 125 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), |
| 125 | item.prompt, | 126 | item.prompt, |
| 127 | item.cprompt, | ||
| 126 | item.nprompt, | 128 | item.nprompt, |
| 127 | ) | 129 | ) |
| 128 | for item in items | 130 | for item in items |
| @@ -160,12 +162,10 @@ class CSVDataModule(): | |||
| 160 | 162 | ||
| 161 | def setup(self, stage=None): | 163 | def setup(self, stage=None): |
| 162 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 164 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |
| 163 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | ||
| 164 | num_class_images=self.num_class_images, | 165 | num_class_images=self.num_class_images, |
| 165 | size=self.size, interpolation=self.interpolation, | 166 | size=self.size, interpolation=self.interpolation, |
| 166 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) | 167 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) |
| 167 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 168 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
| 168 | instance_identifier=self.instance_identifier, | ||
| 169 | size=self.size, interpolation=self.interpolation, | 169 | size=self.size, interpolation=self.interpolation, |
| 170 | center_crop=self.center_crop) | 170 | center_crop=self.center_crop) |
| 171 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 171 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| @@ -187,9 +187,7 @@ class CSVDataset(Dataset): | |||
| 187 | self, | 187 | self, |
| 188 | data: List[CSVDataItem], | 188 | data: List[CSVDataItem], |
| 189 | prompt_processor: PromptProcessor, | 189 | prompt_processor: PromptProcessor, |
| 190 | instance_identifier: str, | ||
| 191 | batch_size: int = 1, | 190 | batch_size: int = 1, |
| 192 | class_identifier: Optional[str] = None, | ||
| 193 | num_class_images: int = 0, | 191 | num_class_images: int = 0, |
| 194 | size: int = 512, | 192 | size: int = 512, |
| 195 | repeats: int = 1, | 193 | repeats: int = 1, |
| @@ -201,8 +199,6 @@ class CSVDataset(Dataset): | |||
| 201 | self.data = data | 199 | self.data = data |
| 202 | self.prompt_processor = prompt_processor | 200 | self.prompt_processor = prompt_processor |
| 203 | self.batch_size = batch_size | 201 | self.batch_size = batch_size |
| 204 | self.instance_identifier = instance_identifier | ||
| 205 | self.class_identifier = class_identifier | ||
| 206 | self.num_class_images = num_class_images | 202 | self.num_class_images = num_class_images |
| 207 | self.dropout = dropout | 203 | self.dropout = dropout |
| 208 | self.image_cache = {} | 204 | self.image_cache = {} |
| @@ -239,14 +235,12 @@ class CSVDataset(Dataset): | |||
| 239 | 235 | ||
| 240 | return image | 236 | return image |
| 241 | 237 | ||
| 242 | def get_input_ids(self, prompt, identifier): | ||
| 243 | return self.prompt_processor.get_input_ids(prompt.format(identifier)) | ||
| 244 | |||
| 245 | def get_example(self, i): | 238 | def get_example(self, i): |
| 246 | item = self.data[i % self.num_instance_images] | 239 | item = self.data[i % self.num_instance_images] |
| 247 | 240 | ||
| 248 | example = {} | 241 | example = {} |
| 249 | example["prompts"] = item.prompt | 242 | example["prompts"] = item.prompt |
| 243 | example["cprompts"] = item.cprompt | ||
| 250 | example["nprompts"] = item.nprompt | 244 | example["nprompts"] = item.nprompt |
| 251 | example["instance_images"] = self.get_image(item.instance_image_path) | 245 | example["instance_images"] = self.get_image(item.instance_image_path) |
| 252 | if self.num_class_images != 0: | 246 | if self.num_class_images != 0: |
| @@ -260,13 +254,14 @@ class CSVDataset(Dataset): | |||
| 260 | example = {} | 254 | example = {} |
| 261 | 255 | ||
| 262 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) | 256 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) |
| 257 | example["cprompts"] = unprocessed_example["cprompts"] | ||
| 263 | example["nprompts"] = unprocessed_example["nprompts"] | 258 | example["nprompts"] = unprocessed_example["nprompts"] |
| 264 | 259 | ||
| 265 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 260 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 266 | example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) | 261 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) |
| 267 | 262 | ||
| 268 | if self.num_class_images != 0: | 263 | if self.num_class_images != 0: |
| 269 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 264 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
| 270 | example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) | 265 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) |
| 271 | 266 | ||
| 272 | return example | 267 | return example |
diff --git a/environment.yaml b/environment.yaml index e598f72..179fa38 100644 --- a/environment.yaml +++ b/environment.yaml | |||
| @@ -21,4 +21,4 @@ dependencies: | |||
| 21 | - test-tube>=0.7.5 | 21 | - test-tube>=0.7.5 |
| 22 | - transformers==4.25.1 | 22 | - transformers==4.25.1 |
| 23 | - triton==2.0.0.dev20221202 | 23 | - triton==2.0.0.dev20221202 |
| 24 | - xformers==0.0.16rc391 | 24 | - xformers==0.0.16rc396 |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 9749c62..ff67d12 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -62,16 +62,10 @@ def parse_args(): | |||
| 62 | default="template", | 62 | default="template", |
| 63 | ) | 63 | ) |
| 64 | parser.add_argument( | 64 | parser.add_argument( |
| 65 | "--instance_identifier", | 65 | "--project", |
| 66 | type=str, | 66 | type=str, |
| 67 | default=None, | 67 | default=None, |
| 68 | help="A token to use as a placeholder for the concept.", | 68 | help="The name of the current project.", |
| 69 | ) | ||
| 70 | parser.add_argument( | ||
| 71 | "--class_identifier", | ||
| 72 | type=str, | ||
| 73 | default=None, | ||
| 74 | help="A token to use as a placeholder for the concept.", | ||
| 75 | ) | 69 | ) |
| 76 | parser.add_argument( | 70 | parser.add_argument( |
| 77 | "--placeholder_token", | 71 | "--placeholder_token", |
| @@ -364,8 +358,8 @@ def parse_args(): | |||
| 364 | if args.pretrained_model_name_or_path is None: | 358 | if args.pretrained_model_name_or_path is None: |
| 365 | raise ValueError("You must specify --pretrained_model_name_or_path") | 359 | raise ValueError("You must specify --pretrained_model_name_or_path") |
| 366 | 360 | ||
| 367 | if args.instance_identifier is None: | 361 | if args.project is None: |
| 368 | raise ValueError("You must specify --instance_identifier") | 362 | raise ValueError("You must specify --project") |
| 369 | 363 | ||
| 370 | if isinstance(args.initializer_token, str): | 364 | if isinstance(args.initializer_token, str): |
| 371 | args.initializer_token = [args.initializer_token] | 365 | args.initializer_token = [args.initializer_token] |
| @@ -397,7 +391,6 @@ class Checkpointer(CheckpointerBase): | |||
| 397 | text_encoder, | 391 | text_encoder, |
| 398 | scheduler, | 392 | scheduler, |
| 399 | output_dir: Path, | 393 | output_dir: Path, |
| 400 | instance_identifier, | ||
| 401 | placeholder_token, | 394 | placeholder_token, |
| 402 | placeholder_token_id, | 395 | placeholder_token_id, |
| 403 | sample_image_size, | 396 | sample_image_size, |
| @@ -408,7 +401,6 @@ class Checkpointer(CheckpointerBase): | |||
| 408 | super().__init__( | 401 | super().__init__( |
| 409 | datamodule=datamodule, | 402 | datamodule=datamodule, |
| 410 | output_dir=output_dir, | 403 | output_dir=output_dir, |
| 411 | instance_identifier=instance_identifier, | ||
| 412 | placeholder_token=placeholder_token, | 404 | placeholder_token=placeholder_token, |
| 413 | placeholder_token_id=placeholder_token_id, | 405 | placeholder_token_id=placeholder_token_id, |
| 414 | sample_image_size=sample_image_size, | 406 | sample_image_size=sample_image_size, |
| @@ -481,13 +473,8 @@ def main(): | |||
| 481 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | 473 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." |
| 482 | ) | 474 | ) |
| 483 | 475 | ||
| 484 | instance_identifier = args.instance_identifier | ||
| 485 | |||
| 486 | if len(args.placeholder_token) != 0: | ||
| 487 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) | ||
| 488 | |||
| 489 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 490 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 477 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
| 491 | basepath.mkdir(parents=True, exist_ok=True) | 478 | basepath.mkdir(parents=True, exist_ok=True) |
| 492 | 479 | ||
| 493 | accelerator = Accelerator( | 480 | accelerator = Accelerator( |
| @@ -637,6 +624,7 @@ def main(): | |||
| 637 | 624 | ||
| 638 | def collate_fn(examples): | 625 | def collate_fn(examples): |
| 639 | prompts = [example["prompts"] for example in examples] | 626 | prompts = [example["prompts"] for example in examples] |
| 627 | cprompts = [example["cprompts"] for example in examples] | ||
| 640 | nprompts = [example["nprompts"] for example in examples] | 628 | nprompts = [example["nprompts"] for example in examples] |
| 641 | input_ids = [example["instance_prompt_ids"] for example in examples] | 629 | input_ids = [example["instance_prompt_ids"] for example in examples] |
| 642 | pixel_values = [example["instance_images"] for example in examples] | 630 | pixel_values = [example["instance_images"] for example in examples] |
| @@ -653,6 +641,7 @@ def main(): | |||
| 653 | 641 | ||
| 654 | batch = { | 642 | batch = { |
| 655 | "prompts": prompts, | 643 | "prompts": prompts, |
| 644 | "cprompts": cprompts, | ||
| 656 | "nprompts": nprompts, | 645 | "nprompts": nprompts, |
| 657 | "input_ids": inputs.input_ids, | 646 | "input_ids": inputs.input_ids, |
| 658 | "pixel_values": pixel_values, | 647 | "pixel_values": pixel_values, |
| @@ -664,8 +653,6 @@ def main(): | |||
| 664 | data_file=args.train_data_file, | 653 | data_file=args.train_data_file, |
| 665 | batch_size=args.train_batch_size, | 654 | batch_size=args.train_batch_size, |
| 666 | prompt_processor=prompt_processor, | 655 | prompt_processor=prompt_processor, |
| 667 | instance_identifier=instance_identifier, | ||
| 668 | class_identifier=args.class_identifier, | ||
| 669 | class_subdir="cls", | 656 | class_subdir="cls", |
| 670 | num_class_images=args.num_class_images, | 657 | num_class_images=args.num_class_images, |
| 671 | size=args.resolution, | 658 | size=args.resolution, |
| @@ -703,7 +690,7 @@ def main(): | |||
| 703 | with torch.autocast("cuda"), torch.inference_mode(): | 690 | with torch.autocast("cuda"), torch.inference_mode(): |
| 704 | for batch in batched_data: | 691 | for batch in batched_data: |
| 705 | image_name = [item.class_image_path for item in batch] | 692 | image_name = [item.class_image_path for item in batch] |
| 706 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] | 693 | prompt = [item.cprompt for item in batch] |
| 707 | nprompt = [item.nprompt for item in batch] | 694 | nprompt = [item.nprompt for item in batch] |
| 708 | 695 | ||
| 709 | images = pipeline( | 696 | images = pipeline( |
| @@ -812,7 +799,6 @@ def main(): | |||
| 812 | text_encoder=text_encoder, | 799 | text_encoder=text_encoder, |
| 813 | scheduler=checkpoint_scheduler, | 800 | scheduler=checkpoint_scheduler, |
| 814 | output_dir=basepath, | 801 | output_dir=basepath, |
| 815 | instance_identifier=instance_identifier, | ||
| 816 | placeholder_token=args.placeholder_token, | 802 | placeholder_token=args.placeholder_token, |
| 817 | placeholder_token_id=placeholder_token_id, | 803 | placeholder_token_id=placeholder_token_id, |
| 818 | sample_image_size=args.sample_image_size, | 804 | sample_image_size=args.sample_image_size, |
diff --git a/train_ti.py b/train_ti.py index 9d06c50..55daa35 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -63,16 +63,10 @@ def parse_args(): | |||
| 63 | default="template", | 63 | default="template", |
| 64 | ) | 64 | ) |
| 65 | parser.add_argument( | 65 | parser.add_argument( |
| 66 | "--instance_identifier", | 66 | "--project", |
| 67 | type=str, | 67 | type=str, |
| 68 | default=None, | 68 | default=None, |
| 69 | help="A token to use as a placeholder for the concept.", | 69 | help="The name of the current project.", |
| 70 | ) | ||
| 71 | parser.add_argument( | ||
| 72 | "--class_identifier", | ||
| 73 | type=str, | ||
| 74 | default=None, | ||
| 75 | help="A token to use as a placeholder for the concept.", | ||
| 76 | ) | 70 | ) |
| 77 | parser.add_argument( | 71 | parser.add_argument( |
| 78 | "--placeholder_token", | 72 | "--placeholder_token", |
| @@ -334,6 +328,9 @@ def parse_args(): | |||
| 334 | if args.pretrained_model_name_or_path is None: | 328 | if args.pretrained_model_name_or_path is None: |
| 335 | raise ValueError("You must specify --pretrained_model_name_or_path") | 329 | raise ValueError("You must specify --pretrained_model_name_or_path") |
| 336 | 330 | ||
| 331 | if args.project is None: | ||
| 332 | raise ValueError("You must specify --project") | ||
| 333 | |||
| 337 | if isinstance(args.initializer_token, str): | 334 | if isinstance(args.initializer_token, str): |
| 338 | args.initializer_token = [args.initializer_token] | 335 | args.initializer_token = [args.initializer_token] |
| 339 | 336 | ||
| @@ -366,7 +363,6 @@ class Checkpointer(CheckpointerBase): | |||
| 366 | text_encoder, | 363 | text_encoder, |
| 367 | scheduler, | 364 | scheduler, |
| 368 | text_embeddings, | 365 | text_embeddings, |
| 369 | instance_identifier, | ||
| 370 | placeholder_token, | 366 | placeholder_token, |
| 371 | placeholder_token_id, | 367 | placeholder_token_id, |
| 372 | output_dir: Path, | 368 | output_dir: Path, |
| @@ -378,7 +374,6 @@ class Checkpointer(CheckpointerBase): | |||
| 378 | super().__init__( | 374 | super().__init__( |
| 379 | datamodule=datamodule, | 375 | datamodule=datamodule, |
| 380 | output_dir=output_dir, | 376 | output_dir=output_dir, |
| 381 | instance_identifier=instance_identifier, | ||
| 382 | placeholder_token=placeholder_token, | 377 | placeholder_token=placeholder_token, |
| 383 | placeholder_token_id=placeholder_token_id, | 378 | placeholder_token_id=placeholder_token_id, |
| 384 | sample_image_size=sample_image_size, | 379 | sample_image_size=sample_image_size, |
| @@ -441,14 +436,9 @@ class Checkpointer(CheckpointerBase): | |||
| 441 | def main(): | 436 | def main(): |
| 442 | args = parse_args() | 437 | args = parse_args() |
| 443 | 438 | ||
| 444 | instance_identifier = args.instance_identifier | ||
| 445 | |||
| 446 | if len(args.placeholder_token) != 0: | ||
| 447 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) | ||
| 448 | |||
| 449 | global_step_offset = args.global_step | 439 | global_step_offset = args.global_step |
| 450 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 440 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 451 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 441 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
| 452 | basepath.mkdir(parents=True, exist_ok=True) | 442 | basepath.mkdir(parents=True, exist_ok=True) |
| 453 | 443 | ||
| 454 | accelerator = Accelerator( | 444 | accelerator = Accelerator( |
| @@ -567,6 +557,7 @@ def main(): | |||
| 567 | 557 | ||
| 568 | def collate_fn(examples): | 558 | def collate_fn(examples): |
| 569 | prompts = [example["prompts"] for example in examples] | 559 | prompts = [example["prompts"] for example in examples] |
| 560 | cprompts = [example["cprompts"] for example in examples] | ||
| 570 | nprompts = [example["nprompts"] for example in examples] | 561 | nprompts = [example["nprompts"] for example in examples] |
| 571 | input_ids = [example["instance_prompt_ids"] for example in examples] | 562 | input_ids = [example["instance_prompt_ids"] for example in examples] |
| 572 | pixel_values = [example["instance_images"] for example in examples] | 563 | pixel_values = [example["instance_images"] for example in examples] |
| @@ -583,6 +574,7 @@ def main(): | |||
| 583 | 574 | ||
| 584 | batch = { | 575 | batch = { |
| 585 | "prompts": prompts, | 576 | "prompts": prompts, |
| 577 | "cprompts": cprompts, | ||
| 586 | "nprompts": nprompts, | 578 | "nprompts": nprompts, |
| 587 | "input_ids": inputs.input_ids, | 579 | "input_ids": inputs.input_ids, |
| 588 | "pixel_values": pixel_values, | 580 | "pixel_values": pixel_values, |
| @@ -594,8 +586,6 @@ def main(): | |||
| 594 | data_file=args.train_data_file, | 586 | data_file=args.train_data_file, |
| 595 | batch_size=args.train_batch_size, | 587 | batch_size=args.train_batch_size, |
| 596 | prompt_processor=prompt_processor, | 588 | prompt_processor=prompt_processor, |
| 597 | instance_identifier=args.instance_identifier, | ||
| 598 | class_identifier=args.class_identifier, | ||
| 599 | class_subdir="cls", | 589 | class_subdir="cls", |
| 600 | num_class_images=args.num_class_images, | 590 | num_class_images=args.num_class_images, |
| 601 | size=args.resolution, | 591 | size=args.resolution, |
| @@ -634,7 +624,7 @@ def main(): | |||
| 634 | with torch.autocast("cuda"), torch.inference_mode(): | 624 | with torch.autocast("cuda"), torch.inference_mode(): |
| 635 | for batch in batched_data: | 625 | for batch in batched_data: |
| 636 | image_name = [item.class_image_path for item in batch] | 626 | image_name = [item.class_image_path for item in batch] |
| 637 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] | 627 | prompt = [item.cprompt for item in batch] |
| 638 | nprompt = [item.nprompt for item in batch] | 628 | nprompt = [item.nprompt for item in batch] |
| 639 | 629 | ||
| 640 | images = pipeline( | 630 | images = pipeline( |
| @@ -744,7 +734,6 @@ def main(): | |||
| 744 | text_encoder=text_encoder, | 734 | text_encoder=text_encoder, |
| 745 | scheduler=checkpoint_scheduler, | 735 | scheduler=checkpoint_scheduler, |
| 746 | text_embeddings=text_embeddings, | 736 | text_embeddings=text_embeddings, |
| 747 | instance_identifier=args.instance_identifier, | ||
| 748 | placeholder_token=args.placeholder_token, | 737 | placeholder_token=args.placeholder_token, |
| 749 | placeholder_token_id=placeholder_token_id, | 738 | placeholder_token_id=placeholder_token_id, |
| 750 | output_dir=basepath, | 739 | output_dir=basepath, |
diff --git a/training/util.py b/training/util.py index e8d22ae..000173d 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -44,7 +44,6 @@ class CheckpointerBase: | |||
| 44 | self, | 44 | self, |
| 45 | datamodule, | 45 | datamodule, |
| 46 | output_dir: Path, | 46 | output_dir: Path, |
| 47 | instance_identifier, | ||
| 48 | placeholder_token, | 47 | placeholder_token, |
| 49 | placeholder_token_id, | 48 | placeholder_token_id, |
| 50 | sample_image_size, | 49 | sample_image_size, |
| @@ -54,7 +53,6 @@ class CheckpointerBase: | |||
| 54 | ): | 53 | ): |
| 55 | self.datamodule = datamodule | 54 | self.datamodule = datamodule |
| 56 | self.output_dir = output_dir | 55 | self.output_dir = output_dir |
| 57 | self.instance_identifier = instance_identifier | ||
| 58 | self.placeholder_token = placeholder_token | 56 | self.placeholder_token = placeholder_token |
| 59 | self.placeholder_token_id = placeholder_token_id | 57 | self.placeholder_token_id = placeholder_token_id |
| 60 | self.sample_image_size = sample_image_size | 58 | self.sample_image_size = sample_image_size |
| @@ -90,7 +88,7 @@ class CheckpointerBase: | |||
| 90 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | 88 | if j * data.batch_size < self.sample_batch_size * self.sample_batches |
| 91 | ] | 89 | ] |
| 92 | prompts = [ | 90 | prompts = [ |
| 93 | prompt.format(identifier=self.instance_identifier) | 91 | prompt |
| 94 | for batch in batches | 92 | for batch in batches |
| 95 | for prompt in batch["prompts"] | 93 | for prompt in batch["prompts"] |
| 96 | ] | 94 | ] |
