author     Volpeon <git@volpeon.ink>    2022-12-23 21:47:12 +0100
committer  Volpeon <git@volpeon.ink>    2022-12-23 21:47:12 +0100
commit     1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch)
tree       42d3302610046dbc5d39d254f7a2d5d5f601aa18 /train_dreambooth.py
parent     Fix (diff)
Improved class prompt handling
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  30
1 file changed, 8 insertions, 22 deletions
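
For context, a minimal sketch of what this commit changes in class prompt handling: the separate --instance_identifier/--class_identifier options are folded into a single --project name, and each class item now carries a pre-formatted cprompt instead of a prompt template that had to be formatted with the class identifier. The ClassImageItem container below is a hypothetical stand-in for the datamodule's item type; only the attribute names that appear in the diff are taken from the source.

# Hedged sketch, not part of the commit: old vs. new class prompt resolution.
from dataclasses import dataclass

@dataclass
class ClassImageItem:  # hypothetical stand-in for the datamodule item
    class_image_path: str
    prompt: str    # template such as "a photo of {identifier}" (old scheme)
    cprompt: str   # pre-formatted class prompt (new in this commit)
    nprompt: str   # negative prompt

def class_prompt_before(item: ClassImageItem, class_identifier: str) -> str:
    # Old behaviour: format the prompt template with the global --class_identifier.
    return item.prompt.format(identifier=class_identifier)

def class_prompt_after(item: ClassImageItem) -> str:
    # New behaviour: the datamodule has already resolved the class prompt.
    return item.cprompt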
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9749c62..ff67d12 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -62,16 +62,10 @@ def parse_args():
         default="template",
     )
     parser.add_argument(
-        "--instance_identifier",
+        "--project",
         type=str,
         default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--class_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
+        help="The name of the current project.",
     )
     parser.add_argument(
         "--placeholder_token",
@@ -364,8 +358,8 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.instance_identifier is None:
-        raise ValueError("You must specify --instance_identifier")
+    if args.project is None:
+        raise ValueError("You must specify --project")
 
     if isinstance(args.initializer_token, str):
         args.initializer_token = [args.initializer_token]
@@ -397,7 +391,6 @@ class Checkpointer(CheckpointerBase):
         text_encoder,
         scheduler,
         output_dir: Path,
-        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         sample_image_size,
@@ -408,7 +401,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            instance_identifier=instance_identifier,
             placeholder_token=placeholder_token,
             placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
@@ -481,13 +473,8 @@ def main():
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )
 
-    instance_identifier = args.instance_identifier
-
-    if len(args.placeholder_token) != 0:
-        instance_identifier = instance_identifier.format(args.placeholder_token[0])
-
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -637,6 +624,7 @@ def main():
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
+        cprompts = [example["cprompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
@@ -653,6 +641,7 @@ def main():
 
         batch = {
             "prompts": prompts,
+            "cprompts": cprompts,
            "nprompts": nprompts,
            "input_ids": inputs.input_ids,
            "pixel_values": pixel_values,
@@ -664,8 +653,6 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
-        instance_identifier=instance_identifier,
-        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -703,7 +690,7 @@ def main():
             with torch.autocast("cuda"), torch.inference_mode():
                 for batch in batched_data:
                     image_name = [item.class_image_path for item in batch]
-                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                    prompt = [item.cprompt for item in batch]
                     nprompt = [item.nprompt for item in batch]
 
                     images = pipeline(
@@ -812,7 +799,6 @@ def main():
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
         output_dir=basepath,
-        instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         sample_image_size=args.sample_image_size,
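
A small usage sketch of how the renamed option feeds the output path after this commit; the helper function name is illustrative (not part of the script), and the import assumes the python-slugify package provides the slugify() the script calls.

# Hedged sketch: output directory derivation from --project (mirrors the new
# basepath lines above; previously slugify() was applied to the formatted
# instance identifier instead).
import datetime
from pathlib import Path

from slugify import slugify  # assumption: python-slugify

def make_basepath(output_dir: str, project: str) -> Path:
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    basepath = Path(output_dir).joinpath(slugify(project), now)
    basepath.mkdir(parents=True, exist_ok=True)
    return basepath

# e.g. make_basepath("output/dreambooth", "my project")
#      -> output/dreambooth/my-project/2022-12-23T21-47-12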