path: root/train_dreambooth.py
author     Volpeon <git@volpeon.ink>    2022-12-23 21:47:12 +0100
committer  Volpeon <git@volpeon.ink>    2022-12-23 21:47:12 +0100
commit     1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch)
tree       42d3302610046dbc5d39d254f7a2d5d5f601aa18 /train_dreambooth.py
parent     Fix (diff)
Improved class prompt handling
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  30
1 file changed, 8 insertions, 22 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9749c62..ff67d12 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -62,16 +62,10 @@ def parse_args():
         default="template",
     )
     parser.add_argument(
-        "--instance_identifier",
+        "--project",
         type=str,
         default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--class_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
+        help="The name of the current project.",
     )
     parser.add_argument(
         "--placeholder_token",
@@ -364,8 +358,8 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.instance_identifier is None:
-        raise ValueError("You must specify --instance_identifier")
+    if args.project is None:
+        raise ValueError("You must specify --project")
 
     if isinstance(args.initializer_token, str):
         args.initializer_token = [args.initializer_token]
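
Illustration (not in this commit): the two placeholder flags --instance_identifier and --class_identifier are collapsed into a single required --project flag. A minimal sketch of the new CLI behaviour, using only the flags visible in this diff and omitting everything else about the parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--project",
    type=str,
    default=None,
    help="The name of the current project.",
)

args = parser.parse_args(["--project", "my-concept"])
if args.project is None:
    raise ValueError("You must specify --project")
print(args.project)  # my-concept
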
@@ -397,7 +391,6 @@ class Checkpointer(CheckpointerBase):
         text_encoder,
         scheduler,
         output_dir: Path,
-        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         sample_image_size,
@@ -408,7 +401,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            instance_identifier=instance_identifier,
             placeholder_token=placeholder_token,
             placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
@@ -481,13 +473,8 @@ def main():
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )
 
-    instance_identifier = args.instance_identifier
-
-    if len(args.placeholder_token) != 0:
-        instance_identifier = instance_identifier.format(args.placeholder_token[0])
-
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
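
Illustration (not in this commit): the run directory is now keyed by the slugified project name rather than the formatted instance identifier. A small sketch of the resulting layout, assuming slugify is the python-slugify package and using made-up values in place of args.output_dir and args.project:

import datetime
from pathlib import Path
from slugify import slugify

output_dir = "output"        # stand-in for args.output_dir
project = "My Cool Project"  # stand-in for args.project

now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
basepath = Path(output_dir).joinpath(slugify(project), now)
basepath.mkdir(parents=True, exist_ok=True)
# -> output/my-cool-project/2022-12-23T21-47-12 (timestamp varies)
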
@@ -637,6 +624,7 @@ def main():
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
+        cprompts = [example["cprompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
@@ -653,6 +641,7 @@ def main():
 
         batch = {
             "prompts": prompts,
+            "cprompts": cprompts,
             "nprompts": nprompts,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
@@ -664,8 +653,6 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
-        instance_identifier=instance_identifier,
-        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -703,7 +690,7 @@ def main():
         with torch.autocast("cuda"), torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                prompt = [item.cprompt for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
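
Illustration (not in this commit): for class-image generation the prompt is no longer built by formatting args.class_identifier into each item's template; the datamodule is expected to hand over the finished prompt as item.cprompt. A before/after sketch with a hypothetical item:

from types import SimpleNamespace

item = SimpleNamespace(
    prompt="a photo of a {identifier}",  # old: per-item template
    cprompt="a photo of a dog",          # new: prompt prepared by the datamodule
)

old_prompt = item.prompt.format(identifier="dog")  # "dog" stands in for args.class_identifier
new_prompt = item.cprompt
assert old_prompt == new_prompt
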
@@ -812,7 +799,6 @@ def main():
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
         output_dir=basepath,
-        instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         sample_image_size=args.sample_image_size,