summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-12 08:05:06 +0100
committerVolpeon <git@volpeon.ink>2022-12-12 08:05:06 +0100
commitdd02ace41f69541044e9db106feaa76bf02da8f6 (patch)
tree8f6a8735acac9ebcf7396a40c632fa81c936701a /textual_inversion.py
parentRemove embedding checkpoints from Dreambooth training (diff)
downloadtextual-inversion-diff-dd02ace41f69541044e9db106feaa76bf02da8f6.tar.gz
textual-inversion-diff-dd02ace41f69541044e9db106feaa76bf02da8f6.tar.bz2
textual-inversion-diff-dd02ace41f69541044e9db106feaa76bf02da8f6.zip
Dreambooth: Support loading Textual Inversion embeddings
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py37
1 files changed, 23 insertions, 14 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index da7c747..a9c3326 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -107,7 +107,7 @@ def parse_args():
107 parser.add_argument( 107 parser.add_argument(
108 "--resolution", 108 "--resolution",
109 type=int, 109 type=int,
110 default=512, 110 default=768,
111 help=( 111 help=(
112 "The resolution for input images, all the images in the train/validation dataset will be resized to this" 112 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
113 " resolution" 113 " resolution"
@@ -119,6 +119,12 @@ def parse_args():
119 help="Whether to center crop images before resizing to resolution" 119 help="Whether to center crop images before resizing to resolution"
120 ) 120 )
121 parser.add_argument( 121 parser.add_argument(
122 "--tag_dropout",
123 type=float,
124 default=0.1,
125 help="Tag dropout probability.",
126 )
127 parser.add_argument(
122 "--dataloader_num_workers", 128 "--dataloader_num_workers",
123 type=int, 129 type=int,
124 default=0, 130 default=0,
@@ -171,9 +177,9 @@ def parse_args():
171 ), 177 ),
172 ) 178 )
173 parser.add_argument( 179 parser.add_argument(
174 "--lr_warmup_steps", 180 "--lr_warmup_epochs",
175 type=int, 181 type=int,
176 default=300, 182 default=10,
177 help="Number of steps for the warmup in the lr scheduler." 183 help="Number of steps for the warmup in the lr scheduler."
178 ) 184 )
179 parser.add_argument( 185 parser.add_argument(
@@ -237,7 +243,7 @@ def parse_args():
237 parser.add_argument( 243 parser.add_argument(
238 "--sample_image_size", 244 "--sample_image_size",
239 type=int, 245 type=int,
240 default=512, 246 default=768,
241 help="Size of sample images", 247 help="Size of sample images",
242 ) 248 )
243 parser.add_argument( 249 parser.add_argument(
@@ -267,7 +273,7 @@ def parse_args():
267 parser.add_argument( 273 parser.add_argument(
268 "--sample_steps", 274 "--sample_steps",
269 type=int, 275 type=int,
270 default=30, 276 default=15,
271 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 277 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
272 ) 278 )
273 parser.add_argument( 279 parser.add_argument(
@@ -399,28 +405,28 @@ class Checkpointer:
399 checkpoints_path = self.output_dir.joinpath("checkpoints") 405 checkpoints_path = self.output_dir.joinpath("checkpoints")
400 checkpoints_path.mkdir(parents=True, exist_ok=True) 406 checkpoints_path.mkdir(parents=True, exist_ok=True)
401 407
402 unwrapped = self.accelerator.unwrap_model(self.text_encoder) 408 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
403 409
404 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): 410 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
405 # Save a checkpoint 411 # Save a checkpoint
406 learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id] 412 learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
407 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} 413 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
408 414
409 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 415 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
410 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) 416 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
411 417
412 del unwrapped 418 del text_encoder
413 del learned_embeds 419 del learned_embeds
414 420
415 @torch.no_grad() 421 @torch.no_grad()
416 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): 422 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
417 samples_path = Path(self.output_dir).joinpath("samples") 423 samples_path = Path(self.output_dir).joinpath("samples")
418 424
419 unwrapped = self.accelerator.unwrap_model(self.text_encoder) 425 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
420 426
421 # Save a sample image 427 # Save a sample image
422 pipeline = VlpnStableDiffusion( 428 pipeline = VlpnStableDiffusion(
423 text_encoder=unwrapped, 429 text_encoder=text_encoder,
424 vae=self.vae, 430 vae=self.vae,
425 unet=self.unet, 431 unet=self.unet,
426 tokenizer=self.tokenizer, 432 tokenizer=self.tokenizer,
@@ -471,7 +477,7 @@ class Checkpointer:
471 negative_prompt=nprompt, 477 negative_prompt=nprompt,
472 height=self.sample_image_size, 478 height=self.sample_image_size,
473 width=self.sample_image_size, 479 width=self.sample_image_size,
474 latents_or_image=latents[:len(prompt)] if latents is not None else None, 480 image=latents[:len(prompt)] if latents is not None else None,
475 generator=generator if latents is not None else None, 481 generator=generator if latents is not None else None,
476 guidance_scale=guidance_scale, 482 guidance_scale=guidance_scale,
477 eta=eta, 483 eta=eta,
@@ -489,7 +495,7 @@ class Checkpointer:
489 del all_samples 495 del all_samples
490 del image_grid 496 del image_grid
491 497
492 del unwrapped 498 del text_encoder
493 del pipeline 499 del pipeline
494 del generator 500 del generator
495 del stable_latents 501 del stable_latents
@@ -662,6 +668,7 @@ def main():
662 num_class_images=args.num_class_images, 668 num_class_images=args.num_class_images,
663 size=args.resolution, 669 size=args.resolution,
664 repeats=args.repeats, 670 repeats=args.repeats,
671 dropout=args.tag_dropout,
665 center_crop=args.center_crop, 672 center_crop=args.center_crop,
666 valid_set_size=args.valid_set_size, 673 valid_set_size=args.valid_set_size,
667 num_workers=args.dataloader_num_workers, 674 num_workers=args.dataloader_num_workers,
@@ -720,6 +727,8 @@ def main():
720 overrode_max_train_steps = True 727 overrode_max_train_steps = True
721 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 728 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
722 729
730 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
731
723 if args.lr_scheduler == "one_cycle": 732 if args.lr_scheduler == "one_cycle":
724 lr_scheduler = get_one_cycle_schedule( 733 lr_scheduler = get_one_cycle_schedule(
725 optimizer=optimizer, 734 optimizer=optimizer,
@@ -728,7 +737,7 @@ def main():
728 elif args.lr_scheduler == "cosine_with_restarts": 737 elif args.lr_scheduler == "cosine_with_restarts":
729 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 738 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
730 optimizer=optimizer, 739 optimizer=optimizer,
731 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 740 num_warmup_steps=warmup_steps,
732 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 741 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
733 num_cycles=args.lr_cycles or math.ceil(math.sqrt( 742 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
734 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), 743 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
@@ -737,7 +746,7 @@ def main():
737 lr_scheduler = get_scheduler( 746 lr_scheduler = get_scheduler(
738 args.lr_scheduler, 747 args.lr_scheduler,
739 optimizer=optimizer, 748 optimizer=optimizer,
740 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 749 num_warmup_steps=warmup_steps,
741 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 750 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
742 ) 751 )
743 752