 dreambooth.py | 56 ++++++++++++++++----------------------------------------
 1 file changed, 16 insertions(+), 40 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 1d6735f..675320b 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -440,44 +440,23 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def save_embedding(self, step, postfix):
-        if len(self.placeholder_token) == 0:
-            return
-
-        print("Saving checkpoint for step %d..." % step)
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
-
-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
-
-    @torch.no_grad()
     def save_model(self):
         print("Saving model...")
 
-        unwrapped_unet = self.accelerator.unwrap_model(
-            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
-        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped_text_encoder,
+            text_encoder=text_encoder,
             vae=self.vae,
-            unet=unwrapped_unet,
+            unet=unet,
             tokenizer=self.tokenizer,
             scheduler=self.scheduler,
         )
         pipeline.save_pretrained(self.output_dir.joinpath("model"))
 
-        del unwrapped_unet
-        del unwrapped_text_encoder
+        del unet
+        del text_encoder
         del pipeline
 
         if torch.cuda.is_available():
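Note on the save_model rewrite above: the old code passed the EMA averaged model through accelerator.unwrap_model, but assuming (as is typical) only the live training model was wrapped by accelerator.prepare(), unwrapping the EMA copy was a no-op. The new code uses the averaged weights directly and unwraps only the raw U-Net. A minimal sketch of the new selection logic (select_unet_for_export is a hypothetical helper; averaged_model follows the EMA wrapper this script already uses):

    def select_unet_for_export(accelerator, unet, ema_unet):
        # EMA enabled: export the exponentially averaged weights as-is;
        # the EMA copy was never wrapped by accelerator.prepare().
        if ema_unet is not None:
            return ema_unet.averaged_model
        # No EMA: strip the accelerate wrapper from the training model.
        return accelerator.unwrap_model(unet)

save_samples applies the identical change in the next hunk.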
@@ -487,14 +466,13 @@ class Checkpointer:
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped_unet = self.accelerator.unwrap_model(
-            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
-        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped_text_encoder,
+            text_encoder=text_encoder,
             vae=self.vae,
-            unet=unwrapped_unet,
+            unet=unet,
             tokenizer=self.tokenizer,
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
@@ -561,8 +539,8 @@ class Checkpointer:
         del all_samples
         del image_grid
 
-        del unwrapped_unet
-        del unwrapped_text_encoder
+        del unet
+        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -981,6 +959,8 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
+                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
+
                 if args.num_class_images != 0:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                     model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
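Note on the added del statements: they drop the last Python references to the large intermediate tensors (timesteps, noise, latents, noisy_latents, encoder_hidden_states) before the loss and backward work, so PyTorch's caching allocator can reuse that memory sooner; tensors the autograd graph still holds (for example the U-Net input) are only reclaimed after backward. A small illustration of the reference-counting behaviour, with hypothetical shapes standing in for the real pipeline:

    import torch
    import torch.nn.functional as F

    latents = torch.randn(4, 4, 64, 64)
    noise = torch.randn_like(latents)
    noisy_latents = latents + noise      # stand-in for noise_scheduler.add_noise
    model_pred = noisy_latents * 2.0     # stand-in for the U-Net forward pass
    target = noise                       # epsilon-prediction target

    # target still points at the same storage as noise, so deleting the names
    # only drops references; tensors with no remaining referents can be freed.
    del noise, latents, noisy_latents

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")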
@@ -1037,10 +1017,6 @@ def main():
                     global_step += 1
 
                     if global_step % args.sample_frequency == 0:
-                        local_progress_bar.clear()
-                        global_progress_bar.clear()
-
-                        checkpointer.save_embedding(global_step, "training")
                         sample_checkpoint = True
 
                 logs = {
@@ -1093,6 +1069,8 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
+                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
+
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 acc = (model_pred == latents).float().mean()
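One caveat in the validation hunk above: the del removes latents, but the unchanged context line acc = (model_pred == latents).float().mean() (new line 1076) still reads it, so as the diff stands that comparison would raise a NameError once this branch runs.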
@@ -1131,7 +1109,6 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.save_embedding(global_step, "end")
             checkpointer.save_model()
 
         accelerator.end_training()
@@ -1139,7 +1116,6 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.save_embedding(global_step, "end")
             checkpointer.save_model()
         accelerator.end_training()
         quit()
