summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-27 17:19:46 +0100
committerVolpeon <git@volpeon.ink>2022-11-27 17:19:46 +0100
commit4523f853fe47592db30ab3e03e89fb917db68464 (patch)
treecbe8260d25933e14f637c97053d498d0f353eb60 /dreambooth.py
parentMake prompt processor compatible with any model (diff)
downloadtextual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.tar.gz
textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.tar.bz2
textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.zip
Generic loading of scheduler (training)
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py29
1 file changed, 9 insertions(+), 20 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 2b8a35e..a266d83 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -387,6 +387,7 @@ class Checkpointer:
387 ema_unet, 387 ema_unet,
388 tokenizer, 388 tokenizer,
389 text_encoder, 389 text_encoder,
390 scheduler,
390 output_dir: Path, 391 output_dir: Path,
391 instance_identifier, 392 instance_identifier,
392 placeholder_token, 393 placeholder_token,
@@ -403,6 +404,7 @@ class Checkpointer:
403 self.ema_unet = ema_unet 404 self.ema_unet = ema_unet
404 self.tokenizer = tokenizer 405 self.tokenizer = tokenizer
405 self.text_encoder = text_encoder 406 self.text_encoder = text_encoder
407 self.scheduler = scheduler
406 self.output_dir = output_dir 408 self.output_dir = output_dir
407 self.instance_identifier = instance_identifier 409 self.instance_identifier = instance_identifier
408 self.placeholder_token = placeholder_token 410 self.placeholder_token = placeholder_token
@@ -445,9 +447,7 @@ class Checkpointer:
445 vae=self.vae, 447 vae=self.vae,
446 unet=unwrapped_unet, 448 unet=unwrapped_unet,
447 tokenizer=self.tokenizer, 449 tokenizer=self.tokenizer,
448 scheduler=PNDMScheduler( 450 scheduler=self.scheduler,
449 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
450 ),
451 ) 451 )
452 pipeline.save_pretrained(self.output_dir.joinpath("model")) 452 pipeline.save_pretrained(self.output_dir.joinpath("model"))
453 453
@@ -466,16 +466,12 @@ class Checkpointer:
466 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) 466 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
467 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) 467 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
468 468
469 scheduler = DPMSolverMultistepScheduler(
470 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
471 )
472
473 pipeline = VlpnStableDiffusion( 469 pipeline = VlpnStableDiffusion(
474 text_encoder=unwrapped_text_encoder, 470 text_encoder=unwrapped_text_encoder,
475 vae=self.vae, 471 vae=self.vae,
476 unet=unwrapped_unet, 472 unet=unwrapped_unet,
477 tokenizer=self.tokenizer, 473 tokenizer=self.tokenizer,
478 scheduler=scheduler, 474 scheduler=self.scheduler,
479 ).to(self.accelerator.device) 475 ).to(self.accelerator.device)
480 pipeline.set_progress_bar_config(dynamic_ncols=True) 476 pipeline.set_progress_bar_config(dynamic_ncols=True)
481 477
@@ -587,6 +583,9 @@ def main():
587 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') 583 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
588 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') 584 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
589 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') 585 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
586 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
587 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
588 args.pretrained_model_name_or_path, subfolder='scheduler')
590 589
591 unet.set_use_memory_efficient_attention_xformers(True) 590 unet.set_use_memory_efficient_attention_xformers(True)
592 591
@@ -690,13 +689,6 @@ def main():
690 eps=args.adam_epsilon, 689 eps=args.adam_epsilon,
691 ) 690 )
692 691
693 noise_scheduler = DDPMScheduler(
694 beta_start=0.00085,
695 beta_end=0.012,
696 beta_schedule="scaled_linear",
697 num_train_timesteps=args.noise_timesteps
698 )
699
700 weight_dtype = torch.float32 692 weight_dtype = torch.float32
701 if args.mixed_precision == "fp16": 693 if args.mixed_precision == "fp16":
702 weight_dtype = torch.float16 694 weight_dtype = torch.float16
@@ -755,16 +747,12 @@ def main():
755 for i in range(0, len(missing_data), args.sample_batch_size) 747 for i in range(0, len(missing_data), args.sample_batch_size)
756 ] 748 ]
757 749
758 scheduler = DPMSolverMultistepScheduler(
759 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
760 )
761
762 pipeline = VlpnStableDiffusion( 750 pipeline = VlpnStableDiffusion(
763 text_encoder=text_encoder, 751 text_encoder=text_encoder,
764 vae=vae, 752 vae=vae,
765 unet=unet, 753 unet=unet,
766 tokenizer=tokenizer, 754 tokenizer=tokenizer,
767 scheduler=scheduler, 755 scheduler=checkpoint_scheduler,
768 ).to(accelerator.device) 756 ).to(accelerator.device)
769 pipeline.set_progress_bar_config(dynamic_ncols=True) 757 pipeline.set_progress_bar_config(dynamic_ncols=True)
770 758
@@ -876,6 +864,7 @@ def main():
876 ema_unet=ema_unet, 864 ema_unet=ema_unet,
877 tokenizer=tokenizer, 865 tokenizer=tokenizer,
878 text_encoder=text_encoder, 866 text_encoder=text_encoder,
867 scheduler=checkpoint_scheduler,
879 output_dir=basepath, 868 output_dir=basepath,
880 instance_identifier=instance_identifier, 869 instance_identifier=instance_identifier,
881 placeholder_token=args.placeholder_token, 870 placeholder_token=args.placeholder_token,