diff options
-rw-r--r-- | dreambooth.py | 56 |
1 files changed, 16 insertions, 40 deletions
diff --git a/dreambooth.py b/dreambooth.py index 1d6735f..675320b 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -440,44 +440,23 @@ class Checkpointer: | |||
440 | self.sample_batch_size = sample_batch_size | 440 | self.sample_batch_size = sample_batch_size |
441 | 441 | ||
442 | @torch.no_grad() | 442 | @torch.no_grad() |
443 | def save_embedding(self, step, postfix): | ||
444 | if len(self.placeholder_token) == 0: | ||
445 | return | ||
446 | |||
447 | print("Saving checkpoint for step %d..." % step) | ||
448 | |||
449 | checkpoints_path = self.output_dir.joinpath("checkpoints") | ||
450 | checkpoints_path.mkdir(parents=True, exist_ok=True) | ||
451 | |||
452 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | ||
453 | |||
454 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): | ||
455 | # Save a checkpoint | ||
456 | learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id] | ||
457 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | ||
458 | |||
459 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | ||
460 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | ||
461 | |||
462 | @torch.no_grad() | ||
463 | def save_model(self): | 443 | def save_model(self): |
464 | print("Saving model...") | 444 | print("Saving model...") |
465 | 445 | ||
466 | unwrapped_unet = self.accelerator.unwrap_model( | 446 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) |
467 | self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) | 447 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
468 | unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) | ||
469 | 448 | ||
470 | pipeline = VlpnStableDiffusion( | 449 | pipeline = VlpnStableDiffusion( |
471 | text_encoder=unwrapped_text_encoder, | 450 | text_encoder=text_encoder, |
472 | vae=self.vae, | 451 | vae=self.vae, |
473 | unet=unwrapped_unet, | 452 | unet=unet, |
474 | tokenizer=self.tokenizer, | 453 | tokenizer=self.tokenizer, |
475 | scheduler=self.scheduler, | 454 | scheduler=self.scheduler, |
476 | ) | 455 | ) |
477 | pipeline.save_pretrained(self.output_dir.joinpath("model")) | 456 | pipeline.save_pretrained(self.output_dir.joinpath("model")) |
478 | 457 | ||
479 | del unwrapped_unet | 458 | del unet |
480 | del unwrapped_text_encoder | 459 | del text_encoder |
481 | del pipeline | 460 | del pipeline |
482 | 461 | ||
483 | if torch.cuda.is_available(): | 462 | if torch.cuda.is_available(): |
@@ -487,14 +466,13 @@ class Checkpointer: | |||
487 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 466 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
488 | samples_path = Path(self.output_dir).joinpath("samples") | 467 | samples_path = Path(self.output_dir).joinpath("samples") |
489 | 468 | ||
490 | unwrapped_unet = self.accelerator.unwrap_model( | 469 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) |
491 | self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) | 470 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
492 | unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) | ||
493 | 471 | ||
494 | pipeline = VlpnStableDiffusion( | 472 | pipeline = VlpnStableDiffusion( |
495 | text_encoder=unwrapped_text_encoder, | 473 | text_encoder=text_encoder, |
496 | vae=self.vae, | 474 | vae=self.vae, |
497 | unet=unwrapped_unet, | 475 | unet=unet, |
498 | tokenizer=self.tokenizer, | 476 | tokenizer=self.tokenizer, |
499 | scheduler=self.scheduler, | 477 | scheduler=self.scheduler, |
500 | ).to(self.accelerator.device) | 478 | ).to(self.accelerator.device) |
@@ -561,8 +539,8 @@ class Checkpointer: | |||
561 | del all_samples | 539 | del all_samples |
562 | del image_grid | 540 | del image_grid |
563 | 541 | ||
564 | del unwrapped_unet | 542 | del unet |
565 | del unwrapped_text_encoder | 543 | del text_encoder |
566 | del pipeline | 544 | del pipeline |
567 | del generator | 545 | del generator |
568 | del stable_latents | 546 | del stable_latents |
@@ -981,6 +959,8 @@ def main(): | |||
981 | else: | 959 | else: |
982 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 960 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
983 | 961 | ||
962 | del timesteps, noise, latents, noisy_latents, encoder_hidden_states | ||
963 | |||
984 | if args.num_class_images != 0: | 964 | if args.num_class_images != 0: |
985 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 965 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
986 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 966 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
@@ -1037,10 +1017,6 @@ def main(): | |||
1037 | global_step += 1 | 1017 | global_step += 1 |
1038 | 1018 | ||
1039 | if global_step % args.sample_frequency == 0: | 1019 | if global_step % args.sample_frequency == 0: |
1040 | local_progress_bar.clear() | ||
1041 | global_progress_bar.clear() | ||
1042 | |||
1043 | checkpointer.save_embedding(global_step, "training") | ||
1044 | sample_checkpoint = True | 1020 | sample_checkpoint = True |
1045 | 1021 | ||
1046 | logs = { | 1022 | logs = { |
@@ -1093,6 +1069,8 @@ def main(): | |||
1093 | else: | 1069 | else: |
1094 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 1070 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
1095 | 1071 | ||
1072 | del timesteps, noise, latents, noisy_latents, encoder_hidden_states | ||
1073 | |||
1096 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 1074 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
1097 | 1075 | ||
1098 | acc = (model_pred == latents).float().mean() | 1076 | acc = (model_pred == latents).float().mean() |
@@ -1131,7 +1109,6 @@ def main(): | |||
1131 | # Create the pipeline using using the trained modules and save it. | 1109 | # Create the pipeline using using the trained modules and save it. |
1132 | if accelerator.is_main_process: | 1110 | if accelerator.is_main_process: |
1133 | print("Finished! Saving final checkpoint and resume state.") | 1111 | print("Finished! Saving final checkpoint and resume state.") |
1134 | checkpointer.save_embedding(global_step, "end") | ||
1135 | checkpointer.save_model() | 1112 | checkpointer.save_model() |
1136 | 1113 | ||
1137 | accelerator.end_training() | 1114 | accelerator.end_training() |
@@ -1139,7 +1116,6 @@ def main(): | |||
1139 | except KeyboardInterrupt: | 1116 | except KeyboardInterrupt: |
1140 | if accelerator.is_main_process: | 1117 | if accelerator.is_main_process: |
1141 | print("Interrupted, saving checkpoint and resume state...") | 1118 | print("Interrupted, saving checkpoint and resume state...") |
1142 | checkpointer.save_embedding(global_step, "end") | ||
1143 | checkpointer.save_model() | 1119 | checkpointer.save_model() |
1144 | accelerator.end_training() | 1120 | accelerator.end_training() |
1145 | quit() | 1121 | quit() |