summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-07 14:54:44 +0200
committerVolpeon <git@volpeon.ink>2022-10-07 14:54:44 +0200
commit2af0d47b44fe02269b1378f7691d258d35544bb3 (patch)
treecb3250de69e17ad3536e0f548805b7a087a041f2 /textual_inversion.py
parentTraining: Create multiple class images per training image (diff)
downloadtextual-inversion-diff-2af0d47b44fe02269b1378f7691d258d35544bb3.tar.gz
textual-inversion-diff-2af0d47b44fe02269b1378f7691d258d35544bb3.tar.bz2
textual-inversion-diff-2af0d47b44fe02269b1378f7691d258d35544bb3.zip
Fix small details
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py70
1 files changed, 35 insertions, 35 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 86fcdfe..4f2de9e 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -19,7 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
19from diffusers.optimization import get_scheduler 19from diffusers.optimization import get_scheduler
20from PIL import Image 20from PIL import Image
21from tqdm.auto import tqdm 21from tqdm.auto import tqdm
22from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 22from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 23from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25import json 25import json
@@ -70,7 +70,7 @@ def parse_args():
70 parser.add_argument( 70 parser.add_argument(
71 "--num_class_images", 71 "--num_class_images",
72 type=int, 72 type=int,
73 default=2, 73 default=4,
74 help="How many class images to generate per training image." 74 help="How many class images to generate per training image."
75 ) 75 )
76 parser.add_argument( 76 parser.add_argument(
@@ -107,7 +107,8 @@ def parse_args():
107 parser.add_argument( 107 parser.add_argument(
108 "--num_train_epochs", 108 "--num_train_epochs",
109 type=int, 109 type=int,
110 default=100) 110 default=100
111 )
111 parser.add_argument( 112 parser.add_argument(
112 "--max_train_steps", 113 "--max_train_steps",
113 type=int, 114 type=int,
@@ -128,7 +129,7 @@ def parse_args():
128 parser.add_argument( 129 parser.add_argument(
129 "--learning_rate", 130 "--learning_rate",
130 type=float, 131 type=float,
131 default=1e-4, 132 default=5e-5,
132 help="Initial learning rate (after the potential warmup period) to use.", 133 help="Initial learning rate (after the potential warmup period) to use.",
133 ) 134 )
134 parser.add_argument( 135 parser.add_argument(
@@ -325,9 +326,10 @@ class Checkpointer:
325 vae, 326 vae,
326 unet, 327 unet,
327 tokenizer, 328 tokenizer,
329 text_encoder,
328 placeholder_token, 330 placeholder_token,
329 placeholder_token_id, 331 placeholder_token_id,
330 output_dir, 332 output_dir: Path,
331 sample_image_size, 333 sample_image_size,
332 sample_batches, 334 sample_batches,
333 sample_batch_size, 335 sample_batch_size,
@@ -338,6 +340,7 @@ class Checkpointer:
338 self.vae = vae 340 self.vae = vae
339 self.unet = unet 341 self.unet = unet
340 self.tokenizer = tokenizer 342 self.tokenizer = tokenizer
343 self.text_encoder = text_encoder
341 self.placeholder_token = placeholder_token 344 self.placeholder_token = placeholder_token
342 self.placeholder_token_id = placeholder_token_id 345 self.placeholder_token_id = placeholder_token_id
343 self.output_dir = output_dir 346 self.output_dir = output_dir
@@ -347,14 +350,14 @@ class Checkpointer:
347 self.sample_batch_size = sample_batch_size 350 self.sample_batch_size = sample_batch_size
348 351
349 @torch.no_grad() 352 @torch.no_grad()
350 def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): 353 def checkpoint(self, step, postfix, path=None):
351 print("Saving checkpoint for step %d..." % step) 354 print("Saving checkpoint for step %d..." % step)
352 355
353 if path is None: 356 if path is None:
354 checkpoints_path = f"{self.output_dir}/checkpoints" 357 checkpoints_path = self.output_dir.joinpath("checkpoints")
355 os.makedirs(checkpoints_path, exist_ok=True) 358 checkpoints_path.mkdir(parents=True, exist_ok=True)
356 359
357 unwrapped = self.accelerator.unwrap_model(text_encoder) 360 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
358 361
359 # Save a checkpoint 362 # Save a checkpoint
360 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] 363 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -364,17 +367,16 @@ class Checkpointer:
364 if path is not None: 367 if path is not None:
365 torch.save(learned_embeds_dict, path) 368 torch.save(learned_embeds_dict, path)
366 else: 369 else:
367 torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") 370 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
368 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
369 371
370 del unwrapped 372 del unwrapped
371 del learned_embeds 373 del learned_embeds
372 374
373 @torch.no_grad() 375 @torch.no_grad()
374 def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): 376 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
375 samples_path = Path(self.output_dir).joinpath("samples") 377 samples_path = Path(self.output_dir).joinpath("samples")
376 378
377 unwrapped = self.accelerator.unwrap_model(text_encoder) 379 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
378 scheduler = EulerAScheduler( 380 scheduler = EulerAScheduler(
379 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 381 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
380 ) 382 )
@@ -608,7 +610,7 @@ def main():
608 tokenizer=tokenizer, 610 tokenizer=tokenizer,
609 instance_identifier=args.placeholder_token, 611 instance_identifier=args.placeholder_token,
610 class_identifier=args.initializer_token, 612 class_identifier=args.initializer_token,
611 class_subdir="ti_cls", 613 class_subdir="cls",
612 num_class_images=args.num_class_images, 614 num_class_images=args.num_class_images,
613 size=args.resolution, 615 size=args.resolution,
614 repeats=args.repeats, 616 repeats=args.repeats,
@@ -664,21 +666,6 @@ def main():
664 train_dataloader = datamodule.train_dataloader() 666 train_dataloader = datamodule.train_dataloader()
665 val_dataloader = datamodule.val_dataloader() 667 val_dataloader = datamodule.val_dataloader()
666 668
667 checkpointer = Checkpointer(
668 datamodule=datamodule,
669 accelerator=accelerator,
670 vae=vae,
671 unet=unet,
672 tokenizer=tokenizer,
673 placeholder_token=args.placeholder_token,
674 placeholder_token_id=placeholder_token_id,
675 output_dir=basepath,
676 sample_image_size=args.sample_image_size,
677 sample_batch_size=args.sample_batch_size,
678 sample_batches=args.sample_batches,
679 seed=args.seed
680 )
681
682 # Scheduler and math around the number of training steps. 669 # Scheduler and math around the number of training steps.
683 overrode_max_train_steps = False 670 overrode_max_train_steps = False
684 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 671 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -733,10 +720,25 @@ def main():
733 global_step = 0 720 global_step = 0
734 min_val_loss = np.inf 721 min_val_loss = np.inf
735 722
723 checkpointer = Checkpointer(
724 datamodule=datamodule,
725 accelerator=accelerator,
726 vae=vae,
727 unet=unet,
728 tokenizer=tokenizer,
729 text_encoder=text_encoder,
730 placeholder_token=args.placeholder_token,
731 placeholder_token_id=placeholder_token_id,
732 output_dir=basepath,
733 sample_image_size=args.sample_image_size,
734 sample_batch_size=args.sample_batch_size,
735 sample_batches=args.sample_batches,
736 seed=args.seed
737 )
738
736 if accelerator.is_main_process: 739 if accelerator.is_main_process:
737 checkpointer.save_samples( 740 checkpointer.save_samples(
738 0, 741 0,
739 text_encoder,
740 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 742 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
741 743
742 local_progress_bar = tqdm( 744 local_progress_bar = tqdm(
@@ -838,7 +840,7 @@ def main():
838 local_progress_bar.clear() 840 local_progress_bar.clear()
839 global_progress_bar.clear() 841 global_progress_bar.clear()
840 842
841 checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) 843 checkpointer.checkpoint(global_step + global_step_offset, "training")
842 save_resume_file(basepath, args, { 844 save_resume_file(basepath, args, {
843 "global_step": global_step + global_step_offset, 845 "global_step": global_step + global_step_offset,
844 "resume_checkpoint": f"{basepath}/checkpoints/last.bin" 846 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -897,13 +899,12 @@ def main():
897 899
898 if min_val_loss > val_loss: 900 if min_val_loss > val_loss:
899 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 901 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
900 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) 902 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
901 min_val_loss = val_loss 903 min_val_loss = val_loss
902 904
903 if sample_checkpoint and accelerator.is_main_process: 905 if sample_checkpoint and accelerator.is_main_process:
904 checkpointer.save_samples( 906 checkpointer.save_samples(
905 global_step + global_step_offset, 907 global_step + global_step_offset,
906 text_encoder,
907 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 908 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
908 909
909 # Create the pipeline using using the trained modules and save it. 910 # Create the pipeline using using the trained modules and save it.
@@ -912,7 +913,6 @@ def main():
912 checkpointer.checkpoint( 913 checkpointer.checkpoint(
913 global_step + global_step_offset, 914 global_step + global_step_offset,
914 "end", 915 "end",
915 text_encoder,
916 path=f"{basepath}/learned_embeds.bin" 916 path=f"{basepath}/learned_embeds.bin"
917 ) 917 )
918 918
@@ -926,7 +926,7 @@ def main():
926 except KeyboardInterrupt: 926 except KeyboardInterrupt:
927 if accelerator.is_main_process: 927 if accelerator.is_main_process:
928 print("Interrupted, saving checkpoint and resume state...") 928 print("Interrupted, saving checkpoint and resume state...")
929 checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder) 929 checkpointer.checkpoint(global_step + global_step_offset, "end")
930 save_resume_file(basepath, args, { 930 save_resume_file(basepath, args, {
931 "global_step": global_step + global_step_offset, 931 "global_step": global_step + global_step_offset,
932 "resume_checkpoint": f"{basepath}/checkpoints/last.bin" 932 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"