From 6a49074dce78615bce54777fb2be3bfd0dd8f780 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 14 Oct 2022 20:03:01 +0200
Subject: Removed aesthetic gradients; training improvements

---
 textual_inversion.py | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index 9d2840d..6627f1f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -57,6 +57,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
@@ -70,7 +76,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -344,12 +350,11 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self, step, postfix, path=None):
+    def checkpoint(self, step, postfix):
         print("Saving checkpoint for step %d..." % step)
 
-        if path is None:
-            checkpoints_path = self.output_dir.joinpath("checkpoints")
-            checkpoints_path.mkdir(parents=True, exist_ok=True)
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
@@ -358,10 +363,7 @@ class Checkpointer:
         learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
         filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
 
-        if path is not None:
-            torch.save(learned_embeds_dict, path)
-        else:
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -595,7 +597,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -631,7 +633,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -898,17 +900,11 @@ def main():
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint(
-                global_step + global_step_offset,
-                "end",
-                path=f"{basepath}/learned_embeds.bin"
-            )
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
            })
-
             accelerator.end_training()
 
     except KeyboardInterrupt:
-- 
cgit v1.2.3-54-g00ecf
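
A short illustration of the prompt-handling change above: class-image prompts are now filled with the new --class_identifier value instead of --initializer_token, while instance prompts keep the placeholder token. The sketch below uses hypothetical template and token values; only the .format() call mirrors the patched code.

    # Minimal sketch of the prompt templating after this commit.
    # The template and token values are hypothetical; in the script the
    # templates come from the training data and the tokens from the CLI args.
    prompt_template = "a photo of a {}"

    placeholder_token = "<*>"    # new default for --placeholder_token
    class_identifier = "person"  # example value for the new --class_identifier

    instance_prompt = prompt_template.format(placeholder_token)  # training images
    class_prompt = prompt_template.format(class_identifier)      # class image generation

    print(instance_prompt)  # a photo of a <*>
    print(class_prompt)     # a photo of a person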
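
A sketch of the simplified checkpoint flow follows: with the path override removed, every checkpoint, including the final "end" one, is written under <output_dir>/checkpoints/ with a slugified filename. The standalone function is a hedged restatement of Checkpointer.checkpoint() for illustration; its name and parameters are assumptions, not code from the repository.

    import torch
    from pathlib import Path
    from slugify import slugify  # python-slugify, as used by the script

    def save_embedding_checkpoint(output_dir, placeholder_token, learned_embeds, step, postfix):
        # Mirrors the patched Checkpointer.checkpoint(): no `path` parameter
        # any more; all checkpoints land in <output_dir>/checkpoints/.
        checkpoints_path = Path(output_dir).joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        # Save only the learned embedding for the placeholder token.
        learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
        filename = "%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))

This is consistent with the resume_checkpoint value written by save_resume_file in the final hunk, which always points into the checkpoints/ directory.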