From 16b6ea4aef0323871bb44e3ae06733b314a3d615 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 11 Dec 2022 16:23:48 +0100
Subject: Remove embedding checkpoints from Dreambooth training

---
 dreambooth.py | 56 ++++++++++++++++----------------------------------------
 1 file changed, 16 insertions(+), 40 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 1d6735f..675320b 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -439,45 +439,24 @@ class Checkpointer:
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
-    @torch.no_grad()
-    def save_embedding(self, step, postfix):
-        if len(self.placeholder_token) == 0:
-            return
-
-        print("Saving checkpoint for step %d..." % step)
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
-
-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
-
     @torch.no_grad()
     def save_model(self):
         print("Saving model...")
 
-        unwrapped_unet = self.accelerator.unwrap_model(
-            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
-        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped_text_encoder,
+            text_encoder=text_encoder,
             vae=self.vae,
-            unet=unwrapped_unet,
+            unet=unet,
             tokenizer=self.tokenizer,
             scheduler=self.scheduler,
         )
         pipeline.save_pretrained(self.output_dir.joinpath("model"))
 
-        del unwrapped_unet
-        del unwrapped_text_encoder
+        del unet
+        del text_encoder
         del pipeline
 
         if torch.cuda.is_available():
@@ -487,14 +466,13 @@ class Checkpointer:
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped_unet = self.accelerator.unwrap_model(
-            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
-        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped_text_encoder,
+            text_encoder=text_encoder,
             vae=self.vae,
-            unet=unwrapped_unet,
+            unet=unet,
             tokenizer=self.tokenizer,
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
@@ -561,8 +539,8 @@ class Checkpointer:
 
         del all_samples
         del image_grid
-        del unwrapped_unet
-        del unwrapped_text_encoder
+        del unet
+        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -981,6 +959,8 @@ def main():
                     else:
                         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
+                    del timesteps, noise, latents, noisy_latents, encoder_hidden_states
+
                     if args.num_class_images != 0:
                         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -1037,10 +1017,6 @@ def main():
                 global_step += 1
 
                 if global_step % args.sample_frequency == 0:
-                    local_progress_bar.clear()
-                    global_progress_bar.clear()
-
-                    checkpointer.save_embedding(global_step, "training")
                     sample_checkpoint = True
 
             logs = {
@@ -1093,6 +1069,8 @@ def main():
                     else:
                         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
+                    del timesteps, noise, latents, noisy_latents, encoder_hidden_states
+
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     acc = (model_pred == latents).float().mean()
@@ -1131,7 +1109,6 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.save_embedding(global_step, "end")
             checkpointer.save_model()
 
         accelerator.end_training()
@@ -1139,7 +1116,6 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.save_embedding(global_step, "end")
             checkpointer.save_model()
         accelerator.end_training()
         quit()
-- 
cgit v1.2.3-54-g00ecf