commit    2af0d47b44fe02269b1378f7691d258d35544bb3
tree      cb3250de69e17ad3536e0f548805b7a087a041f2
parent    Training: Create multiple class images per training image
author    Volpeon <git@volpeon.ink>  2022-10-07 14:54:44 +0200
committer Volpeon <git@volpeon.ink>  2022-10-07 14:54:44 +0200
Fix small details
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  |  70
1 file changed, 35 insertions(+), 35 deletions(-)
```diff
diff --git a/textual_inversion.py b/textual_inversion.py
index 86fcdfe..4f2de9e 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -19,7 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
@@ -70,7 +70,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=2,
+        default=4,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -107,7 +107,8 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100)
+        default=100
+    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -128,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -325,9 +326,10 @@ class Checkpointer:
         vae,
         unet,
         tokenizer,
+        text_encoder,
         placeholder_token,
         placeholder_token_id,
-        output_dir,
+        output_dir: Path,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -338,6 +340,7 @@ class Checkpointer:
         self.vae = vae
         self.unet = unet
         self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -347,14 +350,14 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
+    def checkpoint(self, step, postfix, path=None):
         print("Saving checkpoint for step %d..." % step)
 
         if path is None:
-            checkpoints_path = f"{self.output_dir}/checkpoints"
-            os.makedirs(checkpoints_path, exist_ok=True)
+            checkpoints_path = self.output_dir.joinpath("checkpoints")
+            checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
```
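The `checkpoint` body above is where the switch to `pathlib` happens, which is also why `output_dir` gains the `Path` annotation in the constructor. A minimal sketch of the equivalent pattern, with an illustrative directory name:

```python
from pathlib import Path

output_dir = Path("output/example-run")  # hypothetical run directory
checkpoints_path = output_dir.joinpath("checkpoints")

# Path.mkdir(parents=True, exist_ok=True) replaces os.makedirs(..., exist_ok=True)
checkpoints_path.mkdir(parents=True, exist_ok=True)

# joinpath composes paths without manual f-string separators
print(checkpoints_path.joinpath("last.bin"))  # output/example-run/checkpoints/last.bin
```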
```diff
@@ -364,17 +367,16 @@ class Checkpointer:
         if path is not None:
             torch.save(learned_embeds_dict, path)
         else:
-            torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
-            torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
 
     @torch.no_grad()
-    def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
```
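For context, `learned_embeds_dict` (built just above this hunk) is the usual textual-inversion artifact: a one-entry dict mapping the placeholder token string to its learned embedding row. A sketch of the save/load round trip with stand-in tensors (token string and dimensions are illustrative):

```python
import torch

placeholder_token = "<my-token>"             # hypothetical placeholder string
embedding_matrix = torch.randn(49409, 768)   # stand-in for get_input_embeddings().weight
placeholder_token_id = embedding_matrix.size(0) - 1

# Save: one row of the embedding matrix, keyed by the token, as checkpoint() does.
learned_embeds = embedding_matrix[placeholder_token_id]
torch.save({placeholder_token: learned_embeds.detach().cpu()}, "learned_embeds.bin")

# Load: recover the token name and its vector from the .bin file.
state = torch.load("learned_embeds.bin")
token, embeds = next(iter(state.items()))
print(token, tuple(embeds.shape))  # <my-token> (768,)
```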
```diff
@@ -608,7 +610,7 @@ def main():
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
         class_identifier=args.initializer_token,
-        class_subdir="ti_cls",
+        class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
@@ -664,21 +666,6 @@ def main():
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
-    checkpointer = Checkpointer(
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
-        output_dir=basepath,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -733,10 +720,25 @@ def main():
     global_step = 0
     min_val_loss = np.inf
 
+    checkpointer = Checkpointer(
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
+        output_dir=basepath,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         checkpointer.save_samples(
             0,
-            text_encoder,
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     local_progress_bar = tqdm(
```
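Constructing the `Checkpointer` later means it can be handed `text_encoder` directly (the parameter added to `__init__` above), so the training loop no longer threads the encoder through every `checkpoint()` and `save_samples()` call. A minimal sketch of that dependency-injection pattern, using a stand-in accelerator:

```python
class StubAccelerator:
    """Stand-in for accelerate.Accelerator, only for illustration."""
    def unwrap_model(self, model):
        return model

class Checkpointer:
    def __init__(self, accelerator, text_encoder):
        self.accelerator = accelerator
        self.text_encoder = text_encoder  # injected once at construction

    def checkpoint(self, step, postfix):
        # Callers no longer pass text_encoder; the instance resolves it itself.
        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
        print(f"step {step} ({postfix}): saving {type(unwrapped).__name__}")

checkpointer = Checkpointer(StubAccelerator(), text_encoder=object())
checkpointer.checkpoint(0, "training")  # prints: step 0 (training): saving object
```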
```diff
@@ -838,7 +840,7 @@ def main():
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
-                    checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
                     save_resume_file(basepath, args, {
                         "global_step": global_step + global_step_offset,
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -897,13 +899,12 @@ def main():
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
-                checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
+                checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                 min_val_loss = val_loss
 
             if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(
                     global_step + global_step_offset,
-                    text_encoder,
                     args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
         # Create the pipeline using using the trained modules and save it.
@@ -912,7 +913,6 @@ def main():
             checkpointer.checkpoint(
                 global_step + global_step_offset,
                 "end",
-                text_encoder,
                 path=f"{basepath}/learned_embeds.bin"
             )
 
@@ -926,7 +926,7 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
```
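`save_resume_file` itself is not shown in this diff, only its call sites. Purely as an illustration of the call shape, here is a hypothetical sketch of such a helper, assuming it serializes the state to JSON next to the checkpoints (the real implementation in the script may differ):

```python
import json
from pathlib import Path

def save_resume_file(basepath, args, state):
    """Hypothetical sketch: persist resume state for a later run (args unused here)."""
    payload = {
        "global_step": state["global_step"],
        "resume_checkpoint": state["resume_checkpoint"],
    }
    Path(basepath).joinpath("resume.json").write_text(json.dumps(payload, indent=2))

# Example call, mirroring the shape used in the diff:
save_resume_file("output/example-run", None, {
    "global_step": 1000,
    "resume_checkpoint": "output/example-run/checkpoints/last.bin",
})
```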
