summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-11 16:23:48 +0100
committerVolpeon <git@volpeon.ink>2022-12-11 16:23:48 +0100
commit16b6ea4aef0323871bb44e3ae06733b314a3d615 (patch)
tree3005eab9eab3a039b88af061d0a8adc2446b2209
parentPackage updates (diff)
downloadtextual-inversion-diff-16b6ea4aef0323871bb44e3ae06733b314a3d615.tar.gz
textual-inversion-diff-16b6ea4aef0323871bb44e3ae06733b314a3d615.tar.bz2
textual-inversion-diff-16b6ea4aef0323871bb44e3ae06733b314a3d615.zip
Remove embedding checkpoints from Dreambooth training
-rw-r--r--dreambooth.py56
1 file changed, 16 insertions, 40 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 1d6735f..675320b 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -440,44 +440,23 @@ class Checkpointer:
440 self.sample_batch_size = sample_batch_size 440 self.sample_batch_size = sample_batch_size
441 441
442 @torch.no_grad() 442 @torch.no_grad()
443 def save_embedding(self, step, postfix):
444 if len(self.placeholder_token) == 0:
445 return
446
447 print("Saving checkpoint for step %d..." % step)
448
449 checkpoints_path = self.output_dir.joinpath("checkpoints")
450 checkpoints_path.mkdir(parents=True, exist_ok=True)
451
452 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
453
454 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
455 # Save a checkpoint
456 learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
457 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
458
459 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
460 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
461
462 @torch.no_grad()
463 def save_model(self): 443 def save_model(self):
464 print("Saving model...") 444 print("Saving model...")
465 445
466 unwrapped_unet = self.accelerator.unwrap_model( 446 unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
467 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) 447 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
468 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
469 448
470 pipeline = VlpnStableDiffusion( 449 pipeline = VlpnStableDiffusion(
471 text_encoder=unwrapped_text_encoder, 450 text_encoder=text_encoder,
472 vae=self.vae, 451 vae=self.vae,
473 unet=unwrapped_unet, 452 unet=unet,
474 tokenizer=self.tokenizer, 453 tokenizer=self.tokenizer,
475 scheduler=self.scheduler, 454 scheduler=self.scheduler,
476 ) 455 )
477 pipeline.save_pretrained(self.output_dir.joinpath("model")) 456 pipeline.save_pretrained(self.output_dir.joinpath("model"))
478 457
479 del unwrapped_unet 458 del unet
480 del unwrapped_text_encoder 459 del text_encoder
481 del pipeline 460 del pipeline
482 461
483 if torch.cuda.is_available(): 462 if torch.cuda.is_available():
@@ -487,14 +466,13 @@ class Checkpointer:
487 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 466 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
488 samples_path = Path(self.output_dir).joinpath("samples") 467 samples_path = Path(self.output_dir).joinpath("samples")
489 468
490 unwrapped_unet = self.accelerator.unwrap_model( 469 unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
491 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) 470 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
492 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
493 471
494 pipeline = VlpnStableDiffusion( 472 pipeline = VlpnStableDiffusion(
495 text_encoder=unwrapped_text_encoder, 473 text_encoder=text_encoder,
496 vae=self.vae, 474 vae=self.vae,
497 unet=unwrapped_unet, 475 unet=unet,
498 tokenizer=self.tokenizer, 476 tokenizer=self.tokenizer,
499 scheduler=self.scheduler, 477 scheduler=self.scheduler,
500 ).to(self.accelerator.device) 478 ).to(self.accelerator.device)
@@ -561,8 +539,8 @@ class Checkpointer:
561 del all_samples 539 del all_samples
562 del image_grid 540 del image_grid
563 541
564 del unwrapped_unet 542 del unet
565 del unwrapped_text_encoder 543 del text_encoder
566 del pipeline 544 del pipeline
567 del generator 545 del generator
568 del stable_latents 546 del stable_latents
@@ -981,6 +959,8 @@ def main():
981 else: 959 else:
982 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 960 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
983 961
962 del timesteps, noise, latents, noisy_latents, encoder_hidden_states
963
984 if args.num_class_images != 0: 964 if args.num_class_images != 0:
985 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 965 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
986 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 966 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -1037,10 +1017,6 @@ def main():
1037 global_step += 1 1017 global_step += 1
1038 1018
1039 if global_step % args.sample_frequency == 0: 1019 if global_step % args.sample_frequency == 0:
1040 local_progress_bar.clear()
1041 global_progress_bar.clear()
1042
1043 checkpointer.save_embedding(global_step, "training")
1044 sample_checkpoint = True 1020 sample_checkpoint = True
1045 1021
1046 logs = { 1022 logs = {
@@ -1093,6 +1069,8 @@ def main():
1093 else: 1069 else:
1094 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 1070 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1095 1071
1072 del timesteps, noise, latents, noisy_latents, encoder_hidden_states
1073
1096 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 1074 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1097 1075
1098 acc = (model_pred == latents).float().mean() 1076 acc = (model_pred == latents).float().mean()
@@ -1131,7 +1109,6 @@ def main():
1131 # Create the pipeline using using the trained modules and save it. 1109 # Create the pipeline using using the trained modules and save it.
1132 if accelerator.is_main_process: 1110 if accelerator.is_main_process:
1133 print("Finished! Saving final checkpoint and resume state.") 1111 print("Finished! Saving final checkpoint and resume state.")
1134 checkpointer.save_embedding(global_step, "end")
1135 checkpointer.save_model() 1112 checkpointer.save_model()
1136 1113
1137 accelerator.end_training() 1114 accelerator.end_training()
@@ -1139,7 +1116,6 @@ def main():
1139 except KeyboardInterrupt: 1116 except KeyboardInterrupt:
1140 if accelerator.is_main_process: 1117 if accelerator.is_main_process:
1141 print("Interrupted, saving checkpoint and resume state...") 1118 print("Interrupted, saving checkpoint and resume state...")
1142 checkpointer.save_embedding(global_step, "end")
1143 checkpointer.save_model() 1119 checkpointer.save_model()
1144 accelerator.end_training() 1120 accelerator.end_training()
1145 quit() 1121 quit()