path: root/dreambooth.py
author    Volpeon <git@volpeon.ink>  2022-10-19 17:45:56 +0200
committer Volpeon <git@volpeon.ink>  2022-10-19 17:45:56 +0200
commit    5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa (patch)
tree      fdf3c57c2f238e5381ad5e961e704fedd3e6816a /dreambooth.py
parent    Updated Dreambooth training (diff)
Dreambooth: Added option to insert a new input token; removed Dreambooth Plus
Diffstat (limited to 'dreambooth.py')
 dreambooth.py | 48 ++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 6 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index d1cf535..da8399f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -69,6 +69,18 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--initializer_token",
+        type=str,
+        default=None,
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=400,
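
Note: these two options graft a Textual-Inversion-style token onto Dreambooth
training. "--placeholder_token" names the new pseudo-word, and
"--initializer_token" picks an existing word whose embedding seeds it. A
hypothetical sketch of how they might be passed (flag names are taken from
this diff; "--instance_identifier" is inferred from the args.instance_identifier
usage further down, and all values are placeholders):

    args = parser.parse_args([
        "--instance_identifier", "a photo of {}",
        "--placeholder_token", "<concept>",
        "--initializer_token", "person",
    ])

Because "--placeholder_token" defaults to "<*>" rather than None, the
"args.placeholder_token is not None" check later in the diff passes unless a
caller explicitly clears the value.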
@@ -114,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=3600,
+        default=6000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -131,13 +143,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=3e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=3e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -476,8 +488,13 @@ class Checkpointer:
 def main():
     args = parse_args()
 
+    instance_identifier = args.instance_identifier
+
+    if args.placeholder_token is not None:
+        instance_identifier = instance_identifier.format(args.placeholder_token)
+
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
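
Note: instance_identifier is treated as a str.format template here, so a value
containing "{}" has the placeholder token substituted into it. A minimal sketch
with hypothetical values:

    instance_identifier = "a photo of {}"   # hypothetical --instance_identifier
    placeholder_token = "<concept>"         # hypothetical --placeholder_token
    instance_identifier = instance_identifier.format(placeholder_token)
    # -> "a photo of <concept>"

A template without "{}" is returned unchanged by str.format, so identifiers
from before this change keep working as-is.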
@@ -514,6 +531,25 @@ def main():
         device=accelerator.device
     ) if args.use_ema else None
 
+    if args.initializer_token is not None:
+        # Convert the initializer_token, placeholder_token to ids
+        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
+        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
+        # Add the placeholder token in tokenizer
+        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+        if num_added_tokens == 0:
+            print(f"Re-using existing token {args.placeholder_token}.")
+        else:
+            print(f"Training new token {args.placeholder_token}.")
+        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+        text_encoder.resize_token_embeddings(len(tokenizer))
+        token_embeds = text_encoder.get_input_embeddings()
+        initializer_token_embeddings = token_embeds(initializer_token_ids)
+        token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings
+
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.gradient_checkpointing:
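
Note: the block above resizes the text encoder's embedding matrix and copies
the initializer word's embedding into the new token's row, so training starts
from a meaningful vector instead of a random one. Only the first sub-token of
the initializer is used ("initializer_token_ids[:1]"), which is why the print
statement reports how many embeddings the word maps to. A standalone sketch of
the same steps, assuming a Hugging Face CLIP tokenizer and text encoder (the
model name and tokens are placeholders, not taken from this repository):

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    placeholder_token = "<concept>"  # hypothetical
    initializer_token = "person"     # hypothetical

    # Keep only the first sub-token id of the initializer word.
    initializer_id = tokenizer.encode(initializer_token, add_special_tokens=False)[0]

    # Register the placeholder and grow the embedding matrix to match.
    tokenizer.add_tokens(placeholder_token)
    placeholder_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Copy the initializer's embedding row into the new token's row.
    embeds = text_encoder.get_input_embeddings()
    with torch.no_grad():
        embeds.weight[placeholder_id] = embeds.weight[initializer_id]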
@@ -605,7 +641,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
-        instance_identifier=args.instance_identifier,
+        instance_identifier=instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -735,7 +771,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         output_dir=basepath,
-        instance_identifier=args.instance_identifier,
+        instance_identifier=instance_identifier,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,