| author | Volpeon <git@volpeon.ink> | 2022-10-14 20:03:01 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-14 20:03:01 +0200 |
| commit | 6a49074dce78615bce54777fb2be3bfd0dd8f780 (patch) | |
| tree | 0f7dde5ea6b6343fb6e0a527e5ebb2940d418dce /textual_inversion.py | |
| parent | Added support for Aesthetic Gradients (diff) | |
| download | textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.tar.gz textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.tar.bz2 textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.zip | |
Removed aesthetic gradients; training improvements
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 32 |
|---|---|---|

1 file changed, 14 insertions(+), 18 deletions(-)
```diff
diff --git a/textual_inversion.py b/textual_inversion.py
index 9d2840d..6627f1f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -57,6 +57,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
@@ -70,7 +76,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -344,12 +350,11 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self, step, postfix, path=None):
+    def checkpoint(self, step, postfix):
         print("Saving checkpoint for step %d..." % step)
 
-        if path is None:
-            checkpoints_path = self.output_dir.joinpath("checkpoints")
-            checkpoints_path.mkdir(parents=True, exist_ok=True)
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
@@ -358,10 +363,7 @@ class Checkpointer:
         learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
 
         filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        if path is not None:
-            torch.save(learned_embeds_dict, path)
-        else:
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -595,7 +597,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -631,7 +633,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -898,17 +900,11 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint(
-                global_step + global_step_offset,
-                "end",
-                path=f"{basepath}/learned_embeds.bin"
-            )
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
-
             accelerator.end_training()
 
     except KeyboardInterrupt:
```
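
Two notes on what this change does. First, `Checkpointer.checkpoint()` no longer takes a `path` override, so the final embedding lands in `checkpoints/<slugified-token>_<step>_end.bin` under the output directory instead of a fixed `learned_embeds.bin` at its root; resume logic still points at `checkpoints/last.bin`. Each checkpoint file is a plain `torch.save` dict of `{placeholder_token: embedding}` (see `learned_embeds_dict` above). A minimal loading sketch, assuming the usual Hugging Face `CLIPTokenizer`/`CLIPTextModel` pair such training scripts build on; the model name and file path below are illustrative, not taken from the commit:

```python
# Sketch only, not part of this commit: load an embedding checkpoint written
# by Checkpointer.checkpoint(), i.e. a dict {placeholder_token: tensor}.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

base_model = "runwayml/stable-diffusion-v1-5"  # illustrative base model
tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")

# e.g. "output/checkpoints/last.bin" or ".../<token>_<step>_end.bin"
learned_embeds_dict = torch.load("output/checkpoints/last.bin", map_location="cpu")

# Register the placeholder token(s), grow the embedding matrix, and copy the
# learned vector(s) into the text encoder's input embeddings.
tokenizer.add_tokens(list(learned_embeds_dict.keys()))
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token, embedding in learned_embeds_dict.items():
    token_embeds[tokenizer.convert_tokens_to_ids(token)] = embedding
```

Second, class-image prompts are now formatted with the new `--class_identifier` flag rather than `--initializer_token`, so a template such as `"a photo of a {}"` (hypothetical) renders with the class word while `--placeholder_token` stays reserved for the learned concept.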
