summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-14 20:03:01 +0200
committerVolpeon <git@volpeon.ink>2022-10-14 20:03:01 +0200
commit6a49074dce78615bce54777fb2be3bfd0dd8f780 (patch)
tree0f7dde5ea6b6343fb6e0a527e5ebb2940d418dce /textual_inversion.py
parentAdded support for Aesthetic Gradients (diff)
downloadtextual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.tar.gz
textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.tar.bz2
textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.zip
Removed aesthetic gradients; training improvements
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py32
1 file changed, 14 insertions(+), 18 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 9d2840d..6627f1f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -57,6 +57,12 @@ def parse_args():
57 parser.add_argument( 57 parser.add_argument(
58 "--placeholder_token", 58 "--placeholder_token",
59 type=str, 59 type=str,
60 default="<*>",
61 help="A token to use as a placeholder for the concept.",
62 )
63 parser.add_argument(
64 "--class_identifier",
65 type=str,
60 default=None, 66 default=None,
61 help="A token to use as a placeholder for the concept.", 67 help="A token to use as a placeholder for the concept.",
62 ) 68 )
@@ -70,7 +76,7 @@ def parse_args():
70 "--num_class_images", 76 "--num_class_images",
71 type=int, 77 type=int,
72 default=400, 78 default=400,
73 help="How many class images to generate per training image." 79 help="How many class images to generate."
74 ) 80 )
75 parser.add_argument( 81 parser.add_argument(
76 "--repeats", 82 "--repeats",
@@ -344,12 +350,11 @@ class Checkpointer:
344 self.sample_batch_size = sample_batch_size 350 self.sample_batch_size = sample_batch_size
345 351
346 @torch.no_grad() 352 @torch.no_grad()
347 def checkpoint(self, step, postfix, path=None): 353 def checkpoint(self, step, postfix):
348 print("Saving checkpoint for step %d..." % step) 354 print("Saving checkpoint for step %d..." % step)
349 355
350 if path is None: 356 checkpoints_path = self.output_dir.joinpath("checkpoints")
351 checkpoints_path = self.output_dir.joinpath("checkpoints") 357 checkpoints_path.mkdir(parents=True, exist_ok=True)
352 checkpoints_path.mkdir(parents=True, exist_ok=True)
353 358
354 unwrapped = self.accelerator.unwrap_model(self.text_encoder) 359 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
355 360
@@ -358,10 +363,7 @@ class Checkpointer:
358 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} 363 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
359 364
360 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) 365 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
361 if path is not None: 366 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
362 torch.save(learned_embeds_dict, path)
363 else:
364 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
365 367
366 del unwrapped 368 del unwrapped
367 del learned_embeds 369 del learned_embeds
@@ -595,7 +597,7 @@ def main():
595 batch_size=args.train_batch_size, 597 batch_size=args.train_batch_size,
596 tokenizer=tokenizer, 598 tokenizer=tokenizer,
597 instance_identifier=args.placeholder_token, 599 instance_identifier=args.placeholder_token,
598 class_identifier=args.initializer_token, 600 class_identifier=args.class_identifier,
599 class_subdir="cls", 601 class_subdir="cls",
600 num_class_images=args.num_class_images, 602 num_class_images=args.num_class_images,
601 size=args.resolution, 603 size=args.resolution,
@@ -631,7 +633,7 @@ def main():
631 with torch.inference_mode(): 633 with torch.inference_mode():
632 for batch in batched_data: 634 for batch in batched_data:
633 image_name = [p.class_image_path for p in batch] 635 image_name = [p.class_image_path for p in batch]
634 prompt = [p.prompt.format(args.initializer_token) for p in batch] 636 prompt = [p.prompt.format(args.class_identifier) for p in batch]
635 nprompt = [p.nprompt for p in batch] 637 nprompt = [p.nprompt for p in batch]
636 638
637 images = pipeline( 639 images = pipeline(
@@ -898,17 +900,11 @@ def main():
898 # Create the pipeline using the trained modules and save it. 900 # Create the pipeline using the trained modules and save it.
899 if accelerator.is_main_process: 901 if accelerator.is_main_process:
900 print("Finished! Saving final checkpoint and resume state.") 902 print("Finished! Saving final checkpoint and resume state.")
901 checkpointer.checkpoint( 903 checkpointer.checkpoint(global_step + global_step_offset, "end")
902 global_step + global_step_offset,
903 "end",
904 path=f"{basepath}/learned_embeds.bin"
905 )
906
907 save_resume_file(basepath, args, { 904 save_resume_file(basepath, args, {
908 "global_step": global_step + global_step_offset, 905 "global_step": global_step + global_step_offset,
909 "resume_checkpoint": f"{basepath}/checkpoints/last.bin" 906 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
910 }) 907 })
911
912 accelerator.end_training() 908 accelerator.end_training()
913 909
914 except KeyboardInterrupt: 910 except KeyboardInterrupt: