summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--dreambooth.py45
-rw-r--r--textual_inversion.py70
3 files changed, 59 insertions, 57 deletions
diff --git a/.gitignore b/.gitignore
index 35b4c22..6b9605f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,5 @@ cython_debug/
161 161
162output/ 162output/
163conf/ 163conf/
164embeddings/
164v1-inference.yaml* 165v1-inference.yaml*
diff --git a/dreambooth.py b/dreambooth.py
index 24e6091..a26bea7 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -68,7 +68,7 @@ def parse_args():
68 parser.add_argument( 68 parser.add_argument(
69 "--num_class_images", 69 "--num_class_images",
70 type=int, 70 type=int,
71 default=2, 71 default=4,
72 help="How many class images to generate per training image." 72 help="How many class images to generate per training image."
73 ) 73 )
74 parser.add_argument( 74 parser.add_argument(
@@ -106,7 +106,8 @@ def parse_args():
106 parser.add_argument( 106 parser.add_argument(
107 "--num_train_epochs", 107 "--num_train_epochs",
108 type=int, 108 type=int,
109 default=100) 109 default=100
110 )
110 parser.add_argument( 111 parser.add_argument(
111 "--max_train_steps", 112 "--max_train_steps",
112 type=int, 113 type=int,
@@ -293,7 +294,7 @@ class Checkpointer:
293 unet, 294 unet,
294 tokenizer, 295 tokenizer,
295 text_encoder, 296 text_encoder,
296 output_dir, 297 output_dir: Path,
297 instance_identifier, 298 instance_identifier,
298 sample_image_size, 299 sample_image_size,
299 sample_batches, 300 sample_batches,
@@ -321,14 +322,14 @@ class Checkpointer:
321 pipeline = VlpnStableDiffusion( 322 pipeline = VlpnStableDiffusion(
322 text_encoder=self.text_encoder, 323 text_encoder=self.text_encoder,
323 vae=self.vae, 324 vae=self.vae,
324 unet=self.accelerator.unwrap_model(self.unet), 325 unet=unwrapped,
325 tokenizer=self.tokenizer, 326 tokenizer=self.tokenizer,
326 scheduler=PNDMScheduler( 327 scheduler=PNDMScheduler(
327 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True 328 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
328 ), 329 ),
329 ) 330 )
330 pipeline.enable_attention_slicing() 331 pipeline.enable_attention_slicing()
331 pipeline.save_pretrained(f"{self.output_dir}/model") 332 pipeline.save_pretrained(self.output_dir.joinpath("model"))
332 333
333 del unwrapped 334 del unwrapped
334 del pipeline 335 del pipeline
@@ -524,7 +525,7 @@ def main():
524 tokenizer=tokenizer, 525 tokenizer=tokenizer,
525 instance_identifier=args.instance_identifier, 526 instance_identifier=args.instance_identifier,
526 class_identifier=args.class_identifier, 527 class_identifier=args.class_identifier,
527 class_subdir="db_cls", 528 class_subdir="cls",
528 num_class_images=args.num_class_images, 529 num_class_images=args.num_class_images,
529 size=args.resolution, 530 size=args.resolution,
530 repeats=args.repeats, 531 repeats=args.repeats,
@@ -580,21 +581,6 @@ def main():
580 train_dataloader = datamodule.train_dataloader() 581 train_dataloader = datamodule.train_dataloader()
581 val_dataloader = datamodule.val_dataloader() 582 val_dataloader = datamodule.val_dataloader()
582 583
583 checkpointer = Checkpointer(
584 datamodule=datamodule,
585 accelerator=accelerator,
586 vae=vae,
587 unet=unet,
588 tokenizer=tokenizer,
589 text_encoder=text_encoder,
590 output_dir=basepath,
591 instance_identifier=args.instance_identifier,
592 sample_image_size=args.sample_image_size,
593 sample_batch_size=args.sample_batch_size,
594 sample_batches=args.sample_batches,
595 seed=args.seed
596 )
597
598 # Scheduler and math around the number of training steps. 584 # Scheduler and math around the number of training steps.
599 overrode_max_train_steps = False 585 overrode_max_train_steps = False
600 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 586 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -613,7 +599,7 @@ def main():
613 unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 599 unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
614 ) 600 )
615 601
616 # Move vae and unet to device 602 # Move text_encoder and vae to device
617 text_encoder.to(accelerator.device) 603 text_encoder.to(accelerator.device)
618 vae.to(accelerator.device) 604 vae.to(accelerator.device)
619 605
@@ -649,6 +635,21 @@ def main():
649 global_step = 0 635 global_step = 0
650 min_val_loss = np.inf 636 min_val_loss = np.inf
651 637
638 checkpointer = Checkpointer(
639 datamodule=datamodule,
640 accelerator=accelerator,
641 vae=vae,
642 unet=unet,
643 tokenizer=tokenizer,
644 text_encoder=text_encoder,
645 output_dir=basepath,
646 instance_identifier=args.instance_identifier,
647 sample_image_size=args.sample_image_size,
648 sample_batch_size=args.sample_batch_size,
649 sample_batches=args.sample_batches,
650 seed=args.seed
651 )
652
652 if accelerator.is_main_process: 653 if accelerator.is_main_process:
653 checkpointer.save_samples( 654 checkpointer.save_samples(
654 0, 655 0,
diff --git a/textual_inversion.py b/textual_inversion.py
index 86fcdfe..4f2de9e 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -19,7 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
19from diffusers.optimization import get_scheduler 19from diffusers.optimization import get_scheduler
20from PIL import Image 20from PIL import Image
21from tqdm.auto import tqdm 21from tqdm.auto import tqdm
22from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 22from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 23from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25import json 25import json
@@ -70,7 +70,7 @@ def parse_args():
70 parser.add_argument( 70 parser.add_argument(
71 "--num_class_images", 71 "--num_class_images",
72 type=int, 72 type=int,
73 default=2, 73 default=4,
74 help="How many class images to generate per training image." 74 help="How many class images to generate per training image."
75 ) 75 )
76 parser.add_argument( 76 parser.add_argument(
@@ -107,7 +107,8 @@ def parse_args():
107 parser.add_argument( 107 parser.add_argument(
108 "--num_train_epochs", 108 "--num_train_epochs",
109 type=int, 109 type=int,
110 default=100) 110 default=100
111 )
111 parser.add_argument( 112 parser.add_argument(
112 "--max_train_steps", 113 "--max_train_steps",
113 type=int, 114 type=int,
@@ -128,7 +129,7 @@ def parse_args():
128 parser.add_argument( 129 parser.add_argument(
129 "--learning_rate", 130 "--learning_rate",
130 type=float, 131 type=float,
131 default=1e-4, 132 default=5e-5,
132 help="Initial learning rate (after the potential warmup period) to use.", 133 help="Initial learning rate (after the potential warmup period) to use.",
133 ) 134 )
134 parser.add_argument( 135 parser.add_argument(
@@ -325,9 +326,10 @@ class Checkpointer:
325 vae, 326 vae,
326 unet, 327 unet,
327 tokenizer, 328 tokenizer,
329 text_encoder,
328 placeholder_token, 330 placeholder_token,
329 placeholder_token_id, 331 placeholder_token_id,
330 output_dir, 332 output_dir: Path,
331 sample_image_size, 333 sample_image_size,
332 sample_batches, 334 sample_batches,
333 sample_batch_size, 335 sample_batch_size,
@@ -338,6 +340,7 @@ class Checkpointer:
338 self.vae = vae 340 self.vae = vae
339 self.unet = unet 341 self.unet = unet
340 self.tokenizer = tokenizer 342 self.tokenizer = tokenizer
343 self.text_encoder = text_encoder
341 self.placeholder_token = placeholder_token 344 self.placeholder_token = placeholder_token
342 self.placeholder_token_id = placeholder_token_id 345 self.placeholder_token_id = placeholder_token_id
343 self.output_dir = output_dir 346 self.output_dir = output_dir
@@ -347,14 +350,14 @@ class Checkpointer:
347 self.sample_batch_size = sample_batch_size 350 self.sample_batch_size = sample_batch_size
348 351
349 @torch.no_grad() 352 @torch.no_grad()
350 def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): 353 def checkpoint(self, step, postfix, path=None):
351 print("Saving checkpoint for step %d..." % step) 354 print("Saving checkpoint for step %d..." % step)
352 355
353 if path is None: 356 if path is None:
354 checkpoints_path = f"{self.output_dir}/checkpoints" 357 checkpoints_path = self.output_dir.joinpath("checkpoints")
355 os.makedirs(checkpoints_path, exist_ok=True) 358 checkpoints_path.mkdir(parents=True, exist_ok=True)
356 359
357 unwrapped = self.accelerator.unwrap_model(text_encoder) 360 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
358 361
359 # Save a checkpoint 362 # Save a checkpoint
360 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] 363 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -364,17 +367,16 @@ class Checkpointer:
364 if path is not None: 367 if path is not None:
365 torch.save(learned_embeds_dict, path) 368 torch.save(learned_embeds_dict, path)
366 else: 369 else:
367 torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") 370 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
368 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
369 371
370 del unwrapped 372 del unwrapped
371 del learned_embeds 373 del learned_embeds
372 374
373 @torch.no_grad() 375 @torch.no_grad()
374 def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): 376 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
375 samples_path = Path(self.output_dir).joinpath("samples") 377 samples_path = Path(self.output_dir).joinpath("samples")
376 378
377 unwrapped = self.accelerator.unwrap_model(text_encoder) 379 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
378 scheduler = EulerAScheduler( 380 scheduler = EulerAScheduler(
379 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 381 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
380 ) 382 )
@@ -608,7 +610,7 @@ def main():
608 tokenizer=tokenizer, 610 tokenizer=tokenizer,
609 instance_identifier=args.placeholder_token, 611 instance_identifier=args.placeholder_token,
610 class_identifier=args.initializer_token, 612 class_identifier=args.initializer_token,
611 class_subdir="ti_cls", 613 class_subdir="cls",
612 num_class_images=args.num_class_images, 614 num_class_images=args.num_class_images,
613 size=args.resolution, 615 size=args.resolution,
614 repeats=args.repeats, 616 repeats=args.repeats,
@@ -664,21 +666,6 @@ def main():
664 train_dataloader = datamodule.train_dataloader() 666 train_dataloader = datamodule.train_dataloader()
665 val_dataloader = datamodule.val_dataloader() 667 val_dataloader = datamodule.val_dataloader()
666 668
667 checkpointer = Checkpointer(
668 datamodule=datamodule,
669 accelerator=accelerator,
670 vae=vae,
671 unet=unet,
672 tokenizer=tokenizer,
673 placeholder_token=args.placeholder_token,
674 placeholder_token_id=placeholder_token_id,
675 output_dir=basepath,
676 sample_image_size=args.sample_image_size,
677 sample_batch_size=args.sample_batch_size,
678 sample_batches=args.sample_batches,
679 seed=args.seed
680 )
681
682 # Scheduler and math around the number of training steps. 669 # Scheduler and math around the number of training steps.
683 overrode_max_train_steps = False 670 overrode_max_train_steps = False
684 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 671 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -733,10 +720,25 @@ def main():
733 global_step = 0 720 global_step = 0
734 min_val_loss = np.inf 721 min_val_loss = np.inf
735 722
723 checkpointer = Checkpointer(
724 datamodule=datamodule,
725 accelerator=accelerator,
726 vae=vae,
727 unet=unet,
728 tokenizer=tokenizer,
729 text_encoder=text_encoder,
730 placeholder_token=args.placeholder_token,
731 placeholder_token_id=placeholder_token_id,
732 output_dir=basepath,
733 sample_image_size=args.sample_image_size,
734 sample_batch_size=args.sample_batch_size,
735 sample_batches=args.sample_batches,
736 seed=args.seed
737 )
738
736 if accelerator.is_main_process: 739 if accelerator.is_main_process:
737 checkpointer.save_samples( 740 checkpointer.save_samples(
738 0, 741 0,
739 text_encoder,
740 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 742 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
741 743
742 local_progress_bar = tqdm( 744 local_progress_bar = tqdm(
@@ -838,7 +840,7 @@ def main():
838 local_progress_bar.clear() 840 local_progress_bar.clear()
839 global_progress_bar.clear() 841 global_progress_bar.clear()
840 842
841 checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) 843 checkpointer.checkpoint(global_step + global_step_offset, "training")
842 save_resume_file(basepath, args, { 844 save_resume_file(basepath, args, {
843 "global_step": global_step + global_step_offset, 845 "global_step": global_step + global_step_offset,
844 "resume_checkpoint": f"{basepath}/checkpoints/last.bin" 846 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -897,13 +899,12 @@ def main():
897 899
898 if min_val_loss > val_loss: 900 if min_val_loss > val_loss:
899 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 901 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
900 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) 902 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
901 min_val_loss = val_loss 903 min_val_loss = val_loss
902 904
903 if sample_checkpoint and accelerator.is_main_process: 905 if sample_checkpoint and accelerator.is_main_process:
904 checkpointer.save_samples( 906 checkpointer.save_samples(
905 global_step + global_step_offset, 907 global_step + global_step_offset,
906 text_encoder,
907 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 908 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
908 909
909 # Create the pipeline using using the trained modules and save it. 910 # Create the pipeline using using the trained modules and save it.
@@ -912,7 +913,6 @@ def main():
912 checkpointer.checkpoint( 913 checkpointer.checkpoint(
913 global_step + global_step_offset, 914 global_step + global_step_offset,
914 "end", 915 "end",
915 text_encoder,
916 path=f"{basepath}/learned_embeds.bin" 916 path=f"{basepath}/learned_embeds.bin"
917 ) 917 )
918 918
@@ -926,7 +926,7 @@ def main():
926 except KeyboardInterrupt: 926 except KeyboardInterrupt:
927 if accelerator.is_main_process: 927 if accelerator.is_main_process:
928 print("Interrupted, saving checkpoint and resume state...") 928 print("Interrupted, saving checkpoint and resume state...")
929 checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder) 929 checkpointer.checkpoint(global_step + global_step_offset, "end")
930 save_resume_file(basepath, args, { 930 save_resume_file(basepath, args, {
931 "global_step": global_step + global_step_offset, 931 "global_step": global_step + global_step_offset,
932 "resume_checkpoint": f"{basepath}/checkpoints/last.bin" 932 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"