From 847ec3b6c43c89ef3649715f86ecfed370b6e442 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 24 Oct 2022 07:34:30 +0200
Subject: Update

---
 dreambooth.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/dreambooth.py b/dreambooth.py
index 72c56cd..5c26f12 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -354,6 +354,8 @@ class Checkpointer:
         text_encoder,
         output_dir: Path,
         instance_identifier,
+        placeholder_token,
+        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -368,11 +370,35 @@ class Checkpointer:
         self.text_encoder = text_encoder
         self.output_dir = output_dir
         self.instance_identifier = instance_identifier
+        self.placeholder_token = placeholder_token
+        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
+    @torch.no_grad()
+    def save_embedding(self, step, postfix):
+        if self.placeholder_token_id is None:
+            return
+
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+
+        # Save a checkpoint
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+
+        filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+
+        del unwrapped
+        del learned_embeds
+
     @torch.no_grad()
     def save_model(self):
         print("Saving model...")
@@ -567,6 +593,8 @@ def main():
             text_encoder.text_model.final_layer_norm.parameters(),
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
+    else:
+        placeholder_token_id = None
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -785,6 +813,8 @@ def main():
         text_encoder=text_encoder,
         output_dir=basepath,
         instance_identifier=instance_identifier,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -902,6 +932,7 @@ def main():
                 global_step += 1
 
                 if global_step % args.sample_frequency == 0:
+                    checkpointer.save_embedding(global_step, "training")
                     sample_checkpoint = True
 
                 logs = {
@@ -968,6 +999,7 @@ def main():
             if min_val_loss > val_loss:
                 accelerator.print(
                     f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                checkpointer.save_embedding(global_step, "milestone")
                 min_val_loss = val_loss
 
             if sample_checkpoint and accelerator.is_main_process:
@@ -978,6 +1010,7 @@ def main():
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
+            checkpointer.save_embedding(global_step, "end")
             checkpointer.save_model()
 
             accelerator.end_training()
@@ -985,6 +1018,7 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
+            checkpointer.save_embedding(global_step, "end")
             checkpointer.save_model()
             accelerator.end_training()
             quit()
--
cgit v1.2.3-54-g00ecf
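
For reference, the checkpoint files written by the new save_embedding() method are ordinary torch.save dictionaries mapping the placeholder token string to its embedding tensor, the same {token: tensor} layout commonly used for textual-inversion embeddings. A minimal loading sketch follows, assuming the Hugging Face transformers CLIP classes this script builds on; the model id and checkpoint filename are hypothetical placeholders, not names from the patch.

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    # Hypothetical model id and checkpoint path, for illustration only.
    model_id = "openai/clip-vit-large-patch14"
    tokenizer = CLIPTokenizer.from_pretrained(model_id)
    text_encoder = CLIPTextModel.from_pretrained(model_id)

    # save_embedding() stores {placeholder_token: embedding_tensor} via torch.save.
    learned_embeds = torch.load("checkpoints/my-token_500_milestone.bin", map_location="cpu")

    for token, embedding in learned_embeds.items():
        tokenizer.add_tokens(token)                           # register the placeholder token
        text_encoder.resize_token_embeddings(len(tokenizer))  # grow the embedding matrix to match
        token_id = tokenizer.convert_tokens_to_ids(token)
        with torch.no_grad():
            text_encoder.get_input_embeddings().weight[token_id] = embedding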