From 6a49074dce78615bce54777fb2be3bfd0dd8f780 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 14 Oct 2022 20:03:01 +0200
Subject: Removed aesthetic gradients; training improvements

---
 dreambooth_plus.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 50 insertions(+), 9 deletions(-)

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 7996bc2..b5ec2fc 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -58,6 +58,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="An identifier describing the class the concept belongs to.",
     )
@@ -71,7 +77,7 @@
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -112,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1600,
+        default=2300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-4,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -221,6 +227,12 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--checkpoint_frequency",
+        type=int,
+        default=500,
+        help="How often (in steps) to save a training checkpoint.",
+    )
     parser.add_argument(
         "--sample_frequency",
         type=int,
@@ -352,7 +364,26 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self):
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+
+        # Save a checkpoint
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+
+        filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+
+        del unwrapped
+        del learned_embeds
+
+    @torch.no_grad()
+    def save_model(self):
         print("Saving model...")
 
         unwrapped_unet = self.accelerator.unwrap_model(
@@ -612,7 +643,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -648,7 +679,7 @@ def main():
             with torch.inference_mode():
                 for batch in batched_data:
                     image_name = [p.class_image_path for p in batch]
-                    prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                    prompt = [p.prompt.format(args.class_identifier) for p in batch]
                     nprompt = [p.nprompt for p in batch]
 
                     images = pipeline(
@@ -842,6 +873,12 @@ def main():
                     if global_step % args.sample_frequency == 0:
                         sample_checkpoint = True
 
+                    if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                        local_progress_bar.clear()
+                        global_progress_bar.clear()
+
+                        checkpointer.checkpoint(global_step + global_step_offset, "training")
+
                 logs = {
                     "train/loss": loss,
                     "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -903,6 +940,9 @@ def main():
                 global_progress_bar.clear()
 
                 if min_val_loss > val_loss:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                     min_val_loss = val_loss
 
                 if sample_checkpoint and accelerator.is_main_process:
@@ -913,14 +953,15 @@ def main():
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint()
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
 
         accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint()
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
         accelerator.end_training()
         quit()
-- 
cgit v1.2.3-54-g00ecf
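
Note (not part of the patch): a checkpoint written by the new Checkpointer.checkpoint() is a dict mapping the placeholder token string to its learned embedding tensor. The sketch below shows one way such a .bin file could be loaded back into a text encoder for inference. It assumes the transformers CLIP classes used by Stable Diffusion; the name load_learned_embeds and the example paths are illustrative, not part of this repository.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

def load_learned_embeds(checkpoint_path, tokenizer, text_encoder):
    # Sketch (assumption): not part of dreambooth_plus.py.
    # The checkpoint maps the placeholder token string to its embedding tensor.
    learned_embeds_dict = torch.load(checkpoint_path, map_location="cpu")
    token, embedding = next(iter(learned_embeds_dict.items()))

    # Register the placeholder token and grow the embedding matrix to match.
    if tokenizer.add_tokens(token) == 0:
        raise ValueError(f"Token {token!r} already exists in the tokenizer")
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Overwrite the newly initialized row with the learned embedding.
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embedding
    return token

# Hypothetical usage; model_dir and the checkpoint filename depend on your setup.
# tokenizer = CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer")
# text_encoder = CLIPTextModel.from_pretrained(model_dir, subfolder="text_encoder")
# load_learned_embeds("checkpoints/token_2300_end.bin", tokenizer, text_encoder)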