diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 32 |
1 files changed, 14 insertions, 18 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 9d2840d..6627f1f 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -57,6 +57,12 @@ def parse_args(): | |||
57 | parser.add_argument( | 57 | parser.add_argument( |
58 | "--placeholder_token", | 58 | "--placeholder_token", |
59 | type=str, | 59 | type=str, |
60 | default="<*>", | ||
61 | help="A token to use as a placeholder for the concept.", | ||
62 | ) | ||
63 | parser.add_argument( | ||
64 | "--class_identifier", | ||
65 | type=str, | ||
60 | default=None, | 66 | default=None, |
61 | help="A token to use as a placeholder for the concept.", | 67 | help="A token to use as a placeholder for the concept.", |
62 | ) | 68 | ) |
@@ -70,7 +76,7 @@ def parse_args(): | |||
70 | "--num_class_images", | 76 | "--num_class_images", |
71 | type=int, | 77 | type=int, |
72 | default=400, | 78 | default=400, |
73 | help="How many class images to generate per training image." | 79 | help="How many class images to generate." |
74 | ) | 80 | ) |
75 | parser.add_argument( | 81 | parser.add_argument( |
76 | "--repeats", | 82 | "--repeats", |
@@ -344,12 +350,11 @@ class Checkpointer: | |||
344 | self.sample_batch_size = sample_batch_size | 350 | self.sample_batch_size = sample_batch_size |
345 | 351 | ||
346 | @torch.no_grad() | 352 | @torch.no_grad() |
347 | def checkpoint(self, step, postfix, path=None): | 353 | def checkpoint(self, step, postfix): |
348 | print("Saving checkpoint for step %d..." % step) | 354 | print("Saving checkpoint for step %d..." % step) |
349 | 355 | ||
350 | if path is None: | 356 | checkpoints_path = self.output_dir.joinpath("checkpoints") |
351 | checkpoints_path = self.output_dir.joinpath("checkpoints") | 357 | checkpoints_path.mkdir(parents=True, exist_ok=True) |
352 | checkpoints_path.mkdir(parents=True, exist_ok=True) | ||
353 | 358 | ||
354 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | 359 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) |
355 | 360 | ||
@@ -358,10 +363,7 @@ class Checkpointer: | |||
358 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | 363 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} |
359 | 364 | ||
360 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | 365 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) |
361 | if path is not None: | 366 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) |
362 | torch.save(learned_embeds_dict, path) | ||
363 | else: | ||
364 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | ||
365 | 367 | ||
366 | del unwrapped | 368 | del unwrapped |
367 | del learned_embeds | 369 | del learned_embeds |
@@ -595,7 +597,7 @@ def main(): | |||
595 | batch_size=args.train_batch_size, | 597 | batch_size=args.train_batch_size, |
596 | tokenizer=tokenizer, | 598 | tokenizer=tokenizer, |
597 | instance_identifier=args.placeholder_token, | 599 | instance_identifier=args.placeholder_token, |
598 | class_identifier=args.initializer_token, | 600 | class_identifier=args.class_identifier, |
599 | class_subdir="cls", | 601 | class_subdir="cls", |
600 | num_class_images=args.num_class_images, | 602 | num_class_images=args.num_class_images, |
601 | size=args.resolution, | 603 | size=args.resolution, |
@@ -631,7 +633,7 @@ def main(): | |||
631 | with torch.inference_mode(): | 633 | with torch.inference_mode(): |
632 | for batch in batched_data: | 634 | for batch in batched_data: |
633 | image_name = [p.class_image_path for p in batch] | 635 | image_name = [p.class_image_path for p in batch] |
634 | prompt = [p.prompt.format(args.initializer_token) for p in batch] | 636 | prompt = [p.prompt.format(args.class_identifier) for p in batch] |
635 | nprompt = [p.nprompt for p in batch] | 637 | nprompt = [p.nprompt for p in batch] |
636 | 638 | ||
637 | images = pipeline( | 639 | images = pipeline( |
@@ -898,17 +900,11 @@ def main(): | |||
898 | # Create the pipeline using using the trained modules and save it. | 900 | # Create the pipeline using using the trained modules and save it. |
899 | if accelerator.is_main_process: | 901 | if accelerator.is_main_process: |
900 | print("Finished! Saving final checkpoint and resume state.") | 902 | print("Finished! Saving final checkpoint and resume state.") |
901 | checkpointer.checkpoint( | 903 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
902 | global_step + global_step_offset, | ||
903 | "end", | ||
904 | path=f"{basepath}/learned_embeds.bin" | ||
905 | ) | ||
906 | |||
907 | save_resume_file(basepath, args, { | 904 | save_resume_file(basepath, args, { |
908 | "global_step": global_step + global_step_offset, | 905 | "global_step": global_step + global_step_offset, |
909 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | 906 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" |
910 | }) | 907 | }) |
911 | |||
912 | accelerator.end_training() | 908 | accelerator.end_training() |
913 | 909 | ||
914 | except KeyboardInterrupt: | 910 | except KeyboardInterrupt: |