-rw-r--r--  dreambooth.py                                         | 36
-rw-r--r--  infer.py                                              |  2
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py   | 13
-rw-r--r--  textual_inversion.py                                  | 37
4 files changed, 57 insertions(+), 31 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 675320b..3110c6d 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -118,6 +118,12 @@ def parse_args():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default="embeddings_ti",
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -521,7 +527,7 @@ class Checkpointer:
                     negative_prompt=nprompt,
                     height=self.sample_image_size,
                     width=self.sample_image_size,
-                    latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                    image=latents[:len(prompt)] if latents is not None else None,
                     generator=generator if latents is not None else None,
                     guidance_scale=guidance_scale,
                     eta=eta,
@@ -567,6 +573,8 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
+    embeddings_dir = Path(args.embeddings_dir)
+
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -630,15 +638,25 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    # Resize the token embeddings as we are adding new special tokens to the tokenizer
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
     print(f"Token ID mappings:")
     for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
         print(f"- {token_id} {token}")
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+        embedding_file = embeddings_dir.joinpath(f"{token}.bin")
+        if embedding_file.exists() and embedding_file.is_file():
+            embedding_data = torch.load(embedding_file, map_location="cpu")
+
+            emb = next(iter(embedding_data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+
+            token_embeds[token_id] = emb
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
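Note: the added block above loads previously trained Textual Inversion embeddings from `--embeddings_dir` into the freshly resized token embedding matrix. A minimal standalone sketch of the same idea follows; the model name, token string, and directory are placeholders for illustration, not values from this commit.

```python
from pathlib import Path

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Placeholder model and token, for illustration only.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

token = "<my-token>"
tokenizer.add_tokens(token)
token_id = tokenizer.convert_tokens_to_ids(token)

# Make room for the new token, then overwrite its row with the saved embedding.
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data

embedding_file = Path("embeddings_ti").joinpath(f"{token}.bin")
if embedding_file.is_file():
    # Checkpoints are saved as {placeholder_token: tensor} (see textual_inversion.py below).
    embedding_data = torch.load(embedding_file, map_location="cpu")
    emb = next(iter(embedding_data.values()))
    if len(emb.shape) == 1:
        emb = emb.unsqueeze(0)
    token_embeds[token_id] = emb
```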
@@ -959,8 +977,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 if args.num_class_images != 0:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                     model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -977,6 +993,8 @@ def main():
                 else:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
+                acc = (model_pred == latents).float().mean()
+
                 accelerator.backward(loss)
 
                 if not args.train_text_encoder:
@@ -1004,8 +1022,6 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                acc = (model_pred == latents).float().mean()
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
@@ -1069,8 +1085,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 acc = (model_pred == latents).float().mean()
diff --git a/infer.py b/infer.py
index e3fa9e5..5bd926a 100644
--- a/infer.py
+++ b/infer.py
@@ -291,7 +291,7 @@ def generate(output_dir, pipeline, args):
             num_inference_steps=args.steps,
             guidance_scale=args.guidance_scale,
             generator=generator,
-            latents_or_image=init_image,
+            image=init_image,
             strength=args.image_noise,
         ).images
 
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 78a34d5..141b9a7 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -27,7 +27,9 @@ from models.clip.prompt import PromptProcessor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def preprocess(image, w, h):
+def preprocess(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
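Note: `preprocess()` now derives the working size from the image itself instead of taking explicit `w`/`h` arguments, snapping both dimensions down to a multiple of 8 (the VAE downsampling factor). A rough standalone sketch of the rounding and tensor layout, assuming an RGB PIL image; the final tensor conversion and [-1, 1] scaling follow the usual Stable Diffusion convention and are not shown in the hunk above.

```python
import numpy as np
import PIL.Image
import torch


def preprocess(image: PIL.Image.Image) -> torch.Tensor:
    w, h = image.size
    w, h = map(lambda x: x - x % 8, (w, h))      # snap down to a multiple of 8
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)    # HWC -> NCHW with batch dim
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0                     # scale to [-1, 1], assumed convention
```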
@@ -310,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -373,7 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
-        latents_are_image = isinstance(latents_or_image, PIL.Image.Image)
+        latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         text_embeddings = self.encode_prompt(
@@ -391,9 +393,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 5. Prepare latent variables
         num_channels_latents = self.unet.in_channels
         if latents_are_image:
+            image = preprocess(image)
             latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
             latents = self.prepare_latents_from_image(
-                latents_or_image,
+                image,
                 latent_timestep,
                 batch_size,
                 num_images_per_prompt,
@@ -411,7 +414,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 text_embeddings.dtype,
                 device,
                 generator,
-                latents_or_image,
+                image,
             )
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
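Note: with this rename, call sites pass the init image (or pre-built latents) via `image=` rather than `latents_or_image=`, as infer.py and the two Checkpointers above already do. A hedged usage sketch for the img2img path; `pipeline` is assumed to be an already constructed VlpnStableDiffusion instance (see infer.py for how it is assembled), and the prompt and file names are placeholders.

```python
import PIL.Image

# `pipeline`: an existing VlpnStableDiffusion instance (construction not shown here).
init_image = PIL.Image.open("input.png").convert("RGB")

images = pipeline(
    prompt=["a photo of a cat"],
    negative_prompt=[""],
    image=init_image,          # previously: latents_or_image=init_image
    strength=0.6,
    num_inference_steps=50,
    guidance_scale=7.5,
).images
```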
diff --git a/textual_inversion.py b/textual_inversion.py
index da7c747..a9c3326 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -107,7 +107,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
@@ -119,6 +119,12 @@ def parse_args():
         help="Whether to center crop images before resizing to resolution"
     )
     parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
@@ -171,9 +177,9 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=300,
+        default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -237,7 +243,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -267,7 +273,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=30,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -399,28 +405,28 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
             # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
-        del unwrapped
+        del text_encoder
         del learned_embeds
 
     @torch.no_grad()
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped,
+            text_encoder=text_encoder,
             vae=self.vae,
             unet=self.unet,
             tokenizer=self.tokenizer,
@@ -471,7 +477,7 @@ class Checkpointer:
                     negative_prompt=nprompt,
                     height=self.sample_image_size,
                     width=self.sample_image_size,
-                    latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                    image=latents[:len(prompt)] if latents is not None else None,
                     generator=generator if latents is not None else None,
                     guidance_scale=guidance_scale,
                     eta=eta,
@@ -489,7 +495,7 @@ class Checkpointer:
             del all_samples
             del image_grid
 
-        del unwrapped
+        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -662,6 +668,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        dropout=args.tag_dropout,
         center_crop=args.center_crop,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
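Note: the dataset now receives `dropout=args.tag_dropout` (default 0.1). The dataset implementation itself is not part of this diff; as an assumption about what such a parameter typically does, here is a minimal sketch of per-tag dropout over a comma-separated caption. The function name and caption are hypothetical.

```python
import random


def drop_tags(caption: str, dropout: float = 0.1) -> str:
    """Randomly drop individual comma-separated tags with probability `dropout`."""
    tags = [t.strip() for t in caption.split(",") if t.strip()]
    kept = [t for t in tags if random.random() >= dropout]
    return ", ".join(kept or tags)  # never return an empty caption


# Example: roughly one in ten tags disappears on average.
print(drop_tags("1girl, red hair, smiling, outdoors, sunset", dropout=0.1))
```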
@@ -720,6 +727,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
@@ -728,7 +737,7 @@ def main():
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
                 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
@@ -737,7 +746,7 @@ def main():
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )
 
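Note: warmup is now specified in epochs (`--lr_warmup_epochs`, default 10) and converted to scheduler steps as `lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps`, matching the schedulers above, which also count micro-steps (`max_train_steps * gradient_accumulation_steps`). A quick numeric check under assumed values; the dataset size and batch settings below are placeholders, not from this commit.

```python
import math

# Assumed example values, not taken from the diff.
num_train_batches = 400            # len(train_dataloader)
gradient_accumulation_steps = 4
lr_warmup_epochs = 10

num_update_steps_per_epoch = math.ceil(num_train_batches / gradient_accumulation_steps)  # 100

warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
print(warmup_steps)  # 10 * 100 * 4 = 4000 scheduler (micro) steps
```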