diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
| commit | 1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch) | |
| tree | 42d3302610046dbc5d39d254f7a2d5d5f601aa18 /train_dreambooth.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.gz textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.bz2 textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.zip | |
Improved class prompt handling
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 30 |
1 files changed, 8 insertions, 22 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 9749c62..ff67d12 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -62,16 +62,10 @@ def parse_args(): | |||
| 62 | default="template", | 62 | default="template", |
| 63 | ) | 63 | ) |
| 64 | parser.add_argument( | 64 | parser.add_argument( |
| 65 | "--instance_identifier", | 65 | "--project", |
| 66 | type=str, | 66 | type=str, |
| 67 | default=None, | 67 | default=None, |
| 68 | help="A token to use as a placeholder for the concept.", | 68 | help="The name of the current project.", |
| 69 | ) | ||
| 70 | parser.add_argument( | ||
| 71 | "--class_identifier", | ||
| 72 | type=str, | ||
| 73 | default=None, | ||
| 74 | help="A token to use as a placeholder for the concept.", | ||
| 75 | ) | 69 | ) |
| 76 | parser.add_argument( | 70 | parser.add_argument( |
| 77 | "--placeholder_token", | 71 | "--placeholder_token", |
| @@ -364,8 +358,8 @@ def parse_args(): | |||
| 364 | if args.pretrained_model_name_or_path is None: | 358 | if args.pretrained_model_name_or_path is None: |
| 365 | raise ValueError("You must specify --pretrained_model_name_or_path") | 359 | raise ValueError("You must specify --pretrained_model_name_or_path") |
| 366 | 360 | ||
| 367 | if args.instance_identifier is None: | 361 | if args.project is None: |
| 368 | raise ValueError("You must specify --instance_identifier") | 362 | raise ValueError("You must specify --project") |
| 369 | 363 | ||
| 370 | if isinstance(args.initializer_token, str): | 364 | if isinstance(args.initializer_token, str): |
| 371 | args.initializer_token = [args.initializer_token] | 365 | args.initializer_token = [args.initializer_token] |
| @@ -397,7 +391,6 @@ class Checkpointer(CheckpointerBase): | |||
| 397 | text_encoder, | 391 | text_encoder, |
| 398 | scheduler, | 392 | scheduler, |
| 399 | output_dir: Path, | 393 | output_dir: Path, |
| 400 | instance_identifier, | ||
| 401 | placeholder_token, | 394 | placeholder_token, |
| 402 | placeholder_token_id, | 395 | placeholder_token_id, |
| 403 | sample_image_size, | 396 | sample_image_size, |
| @@ -408,7 +401,6 @@ class Checkpointer(CheckpointerBase): | |||
| 408 | super().__init__( | 401 | super().__init__( |
| 409 | datamodule=datamodule, | 402 | datamodule=datamodule, |
| 410 | output_dir=output_dir, | 403 | output_dir=output_dir, |
| 411 | instance_identifier=instance_identifier, | ||
| 412 | placeholder_token=placeholder_token, | 404 | placeholder_token=placeholder_token, |
| 413 | placeholder_token_id=placeholder_token_id, | 405 | placeholder_token_id=placeholder_token_id, |
| 414 | sample_image_size=sample_image_size, | 406 | sample_image_size=sample_image_size, |
| @@ -481,13 +473,8 @@ def main(): | |||
| 481 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | 473 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." |
| 482 | ) | 474 | ) |
| 483 | 475 | ||
| 484 | instance_identifier = args.instance_identifier | ||
| 485 | |||
| 486 | if len(args.placeholder_token) != 0: | ||
| 487 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) | ||
| 488 | |||
| 489 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 490 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 477 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
| 491 | basepath.mkdir(parents=True, exist_ok=True) | 478 | basepath.mkdir(parents=True, exist_ok=True) |
| 492 | 479 | ||
| 493 | accelerator = Accelerator( | 480 | accelerator = Accelerator( |
| @@ -637,6 +624,7 @@ def main(): | |||
| 637 | 624 | ||
| 638 | def collate_fn(examples): | 625 | def collate_fn(examples): |
| 639 | prompts = [example["prompts"] for example in examples] | 626 | prompts = [example["prompts"] for example in examples] |
| 627 | cprompts = [example["cprompts"] for example in examples] | ||
| 640 | nprompts = [example["nprompts"] for example in examples] | 628 | nprompts = [example["nprompts"] for example in examples] |
| 641 | input_ids = [example["instance_prompt_ids"] for example in examples] | 629 | input_ids = [example["instance_prompt_ids"] for example in examples] |
| 642 | pixel_values = [example["instance_images"] for example in examples] | 630 | pixel_values = [example["instance_images"] for example in examples] |
| @@ -653,6 +641,7 @@ def main(): | |||
| 653 | 641 | ||
| 654 | batch = { | 642 | batch = { |
| 655 | "prompts": prompts, | 643 | "prompts": prompts, |
| 644 | "cprompts": cprompts, | ||
| 656 | "nprompts": nprompts, | 645 | "nprompts": nprompts, |
| 657 | "input_ids": inputs.input_ids, | 646 | "input_ids": inputs.input_ids, |
| 658 | "pixel_values": pixel_values, | 647 | "pixel_values": pixel_values, |
| @@ -664,8 +653,6 @@ def main(): | |||
| 664 | data_file=args.train_data_file, | 653 | data_file=args.train_data_file, |
| 665 | batch_size=args.train_batch_size, | 654 | batch_size=args.train_batch_size, |
| 666 | prompt_processor=prompt_processor, | 655 | prompt_processor=prompt_processor, |
| 667 | instance_identifier=instance_identifier, | ||
| 668 | class_identifier=args.class_identifier, | ||
| 669 | class_subdir="cls", | 656 | class_subdir="cls", |
| 670 | num_class_images=args.num_class_images, | 657 | num_class_images=args.num_class_images, |
| 671 | size=args.resolution, | 658 | size=args.resolution, |
| @@ -703,7 +690,7 @@ def main(): | |||
| 703 | with torch.autocast("cuda"), torch.inference_mode(): | 690 | with torch.autocast("cuda"), torch.inference_mode(): |
| 704 | for batch in batched_data: | 691 | for batch in batched_data: |
| 705 | image_name = [item.class_image_path for item in batch] | 692 | image_name = [item.class_image_path for item in batch] |
| 706 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] | 693 | prompt = [item.cprompt for item in batch] |
| 707 | nprompt = [item.nprompt for item in batch] | 694 | nprompt = [item.nprompt for item in batch] |
| 708 | 695 | ||
| 709 | images = pipeline( | 696 | images = pipeline( |
| @@ -812,7 +799,6 @@ def main(): | |||
| 812 | text_encoder=text_encoder, | 799 | text_encoder=text_encoder, |
| 813 | scheduler=checkpoint_scheduler, | 800 | scheduler=checkpoint_scheduler, |
| 814 | output_dir=basepath, | 801 | output_dir=basepath, |
| 815 | instance_identifier=instance_identifier, | ||
| 816 | placeholder_token=args.placeholder_token, | 802 | placeholder_token=args.placeholder_token, |
| 817 | placeholder_token_id=placeholder_token_id, | 803 | placeholder_token_id=placeholder_token_id, |
| 818 | sample_image_size=args.sample_image_size, | 804 | sample_image_size=args.sample_image_size, |
