author     Volpeon <git@volpeon.ink>   2022-12-12 08:05:06 +0100
committer  Volpeon <git@volpeon.ink>   2022-12-12 08:05:06 +0100
commit     dd02ace41f69541044e9db106feaa76bf02da8f6
tree       8f6a8735acac9ebcf7396a40c632fa81c936701a
parent     Remove embedding checkpoints from Dreambooth training
Dreambooth: Support loading Textual Inversion embeddings
-rw-r--r--   dreambooth.py                                        | 36
-rw-r--r--   infer.py                                             |  2
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 13
-rw-r--r--   textual_inversion.py                                 | 37

4 files changed, 57 insertions(+), 31 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 675320b..3110c6d 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -118,6 +118,12 @@ def parse_args():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default="embeddings_ti",
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -521,7 +527,7 @@ class Checkpointer:
                     negative_prompt=nprompt,
                     height=self.sample_image_size,
                     width=self.sample_image_size,
-                    latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                    image=latents[:len(prompt)] if latents is not None else None,
                     generator=generator if latents is not None else None,
                     guidance_scale=guidance_scale,
                     eta=eta,
@@ -567,6 +573,8 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
+    embeddings_dir = Path(args.embeddings_dir)
+
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -630,15 +638,25 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    # Resize the token embeddings as we are adding new special tokens to the tokenizer
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
     print(f"Token ID mappings:")
     for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
         print(f"- {token_id} {token}")
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+        embedding_file = embeddings_dir.joinpath(f"{token}.bin")
+        if embedding_file.exists() and embedding_file.is_file():
+            embedding_data = torch.load(embedding_file, map_location="cpu")
+
+            emb = next(iter(embedding_data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+
+            token_embeds[token_id] = emb
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
@@ -959,8 +977,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 if args.num_class_images != 0:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                     model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -977,6 +993,8 @@ def main():
                 else:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
+                acc = (model_pred == latents).float().mean()
+
                 accelerator.backward(loss)
 
                 if not args.train_text_encoder:
@@ -1004,8 +1022,6 @@ def main():
                         ema_unet.step(unet)
                     optimizer.zero_grad(set_to_none=True)
 
-                acc = (model_pred == latents).float().mean()
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
@@ -1069,8 +1085,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 acc = (model_pred == latents).float().mean()
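The hunks above are the core of this commit: after the placeholder tokens are added and the embedding matrix is resized, any matching `<token>.bin` file found in `--embeddings_dir` is copied into the corresponding row of the token embedding matrix, so Dreambooth starts from vectors already trained by textual_inversion.py. A minimal standalone sketch of that loading step (the function name and paths are illustrative, not part of the repository):

```python
from pathlib import Path

import torch


def load_ti_embedding(text_encoder, tokenizer, token: str, embeddings_dir: Path) -> bool:
    """Copy a saved Textual Inversion vector for `token` into the text encoder.

    Assumes `token` was already added to the tokenizer and the embedding matrix
    resized, and that the .bin file is a dict {placeholder_token: tensor}, as
    written by Checkpointer.checkpoint() in textual_inversion.py.
    """
    embedding_file = embeddings_dir.joinpath(f"{token}.bin")
    if not embedding_file.is_file():
        return False

    embedding_data = torch.load(embedding_file, map_location="cpu")
    emb = next(iter(embedding_data.values()))   # single-entry dict: take the tensor
    if len(emb.shape) > 1:
        emb = emb.squeeze(0)                    # (1, hidden_dim) -> (hidden_dim,)

    token_id = tokenizer.convert_tokens_to_ids(token)
    token_embeds = text_encoder.get_input_embeddings().weight.data
    token_embeds[token_id] = emb                # overwrite the freshly initialised row
    return True
```

Tokens without a matching file simply keep the initializer-token embedding that is set up further down in main().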
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -291,7 +291,7 @@ def generate(output_dir, pipeline, args):
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
                 generator=generator,
-                latents_or_image=init_image,
+                image=init_image,
                 strength=args.image_noise,
             ).images
 
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 78a34d5..141b9a7 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -27,7 +27,9 @@ from models.clip.prompt import PromptProcessor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def preprocess(image, w, h):
+def preprocess(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -310,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -373,7 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
-        latents_are_image = isinstance(latents_or_image, PIL.Image.Image)
+        latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         text_embeddings = self.encode_prompt(
@@ -391,9 +393,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 5. Prepare latent variables
         num_channels_latents = self.unet.in_channels
         if latents_are_image:
+            image = preprocess(image)
             latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
             latents = self.prepare_latents_from_image(
-                latents_or_image,
+                image,
                 latent_timestep,
                 batch_size,
                 num_images_per_prompt,
@@ -411,7 +414,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 text_embeddings.dtype,
                 device,
                 generator,
-                latents_or_image,
+                image,
             )
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
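With this change the pipeline accepts either a latent tensor or a PIL image through the single `image` argument, and `preprocess()` now derives the working size from the image itself, rounding each side down to the nearest multiple of 8 so the VAE can encode it (a 514×389 input, for example, is processed at 512×384). A rough img2img usage sketch, assuming a `pipeline` instance of `VlpnStableDiffusion` has already been assembled the way infer.py does (prompt, file names, and seed are placeholders):

```python
import torch
from PIL import Image

# `pipeline` is assumed to be an already-assembled VlpnStableDiffusion instance.
init_image = Image.open("input.png").convert("RGB")  # no explicit width/height needed any more

result = pipeline(
    prompt="a photo of <token>",
    image=init_image,            # renamed from latents_or_image
    strength=0.6,                # noise strength for img2img (args.image_noise in infer.py)
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator(device="cuda").manual_seed(42),
)
result.images[0].save("output.png")
```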
diff --git a/textual_inversion.py b/textual_inversion.py
index da7c747..a9c3326 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -107,7 +107,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -119,6 +119,12 @@ def parse_args():
         help="Whether to center crop images before resizing to resolution"
     )
     parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
@@ -171,9 +177,9 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=300,
+        default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -237,7 +243,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -267,7 +273,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=30,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -399,28 +405,28 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
             # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
-        del unwrapped
+        del text_encoder
         del learned_embeds
 
     @torch.no_grad()
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped,
+            text_encoder=text_encoder,
             vae=self.vae,
             unet=self.unet,
             tokenizer=self.tokenizer,
@@ -471,7 +477,7 @@ class Checkpointer:
                     negative_prompt=nprompt,
                     height=self.sample_image_size,
                     width=self.sample_image_size,
-                    latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                    image=latents[:len(prompt)] if latents is not None else None,
                     generator=generator if latents is not None else None,
                     guidance_scale=guidance_scale,
                     eta=eta,
@@ -489,7 +495,7 @@ class Checkpointer:
             del all_samples
             del image_grid
 
-        del unwrapped
+        del text_encoder
         del pipeline
         del generator
         del stable_latents
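For reference, each checkpoint written by `Checkpointer.checkpoint()` is a small dict mapping the placeholder token to its learned embedding vector; this is exactly the layout the new loading code in dreambooth.py expects to find in `--embeddings_dir`. A hedged sketch of inspecting such a file (the file name below is made up; the real pattern is slugified token, step, and postfix):

```python
import torch

# A file saved by Checkpointer.checkpoint(), e.g. "my-token_5000_end.bin" (illustrative name).
embedding_data = torch.load("checkpoints/my-token_5000_end.bin", map_location="cpu")

for placeholder_token, learned_embeds in embedding_data.items():
    # learned_embeds is the embedding row saved with .detach().cpu()
    print(placeholder_token, tuple(learned_embeds.shape))
```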
@@ -662,6 +668,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        dropout=args.tag_dropout,
         center_crop=args.center_crop,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -720,6 +727,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
@@ -728,7 +737,7 @@ def main():
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
                 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
@@ -737,7 +746,7 @@ def main():
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )
 
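Finally, warmup is now configured in epochs instead of raw steps and converted back to scheduler steps via `warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps`, matching how `num_training_steps` is scaled by `gradient_accumulation_steps` in the same calls. A quick worked example with made-up numbers:

```python
import math

# Hypothetical run: 1000 batches per epoch, gradients accumulated over 4 batches.
num_batches_per_epoch = 1000
gradient_accumulation_steps = 4
num_update_steps_per_epoch = math.ceil(num_batches_per_epoch / gradient_accumulation_steps)  # 250

lr_warmup_epochs = 10  # new default in textual_inversion.py

warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
print(warmup_steps)  # 10 * 250 * 4 = 10000 scheduler steps of warmup
```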