summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-27 17:19:46 +0100
committerVolpeon <git@volpeon.ink>2022-11-27 17:19:46 +0100
commit4523f853fe47592db30ab3e03e89fb917db68464 (patch)
treecbe8260d25933e14f637c97053d498d0f353eb60
parentMake prompt processor compatible with any model (diff)
downloadtextual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.tar.gz
textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.tar.bz2
textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.zip
Generic loading of scheduler (training)
-rw-r--r--dreambooth.py29
-rw-r--r--textual_inversion.py23
2 files changed, 18 insertions, 34 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 2b8a35e..a266d83 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -387,6 +387,7 @@ class Checkpointer:
387 ema_unet, 387 ema_unet,
388 tokenizer, 388 tokenizer,
389 text_encoder, 389 text_encoder,
390 scheduler,
390 output_dir: Path, 391 output_dir: Path,
391 instance_identifier, 392 instance_identifier,
392 placeholder_token, 393 placeholder_token,
@@ -403,6 +404,7 @@ class Checkpointer:
403 self.ema_unet = ema_unet 404 self.ema_unet = ema_unet
404 self.tokenizer = tokenizer 405 self.tokenizer = tokenizer
405 self.text_encoder = text_encoder 406 self.text_encoder = text_encoder
407 self.scheduler = scheduler
406 self.output_dir = output_dir 408 self.output_dir = output_dir
407 self.instance_identifier = instance_identifier 409 self.instance_identifier = instance_identifier
408 self.placeholder_token = placeholder_token 410 self.placeholder_token = placeholder_token
@@ -445,9 +447,7 @@ class Checkpointer:
445 vae=self.vae, 447 vae=self.vae,
446 unet=unwrapped_unet, 448 unet=unwrapped_unet,
447 tokenizer=self.tokenizer, 449 tokenizer=self.tokenizer,
448 scheduler=PNDMScheduler( 450 scheduler=self.scheduler,
449 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
450 ),
451 ) 451 )
452 pipeline.save_pretrained(self.output_dir.joinpath("model")) 452 pipeline.save_pretrained(self.output_dir.joinpath("model"))
453 453
@@ -466,16 +466,12 @@ class Checkpointer:
466 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) 466 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
467 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) 467 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
468 468
469 scheduler = DPMSolverMultistepScheduler(
470 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
471 )
472
473 pipeline = VlpnStableDiffusion( 469 pipeline = VlpnStableDiffusion(
474 text_encoder=unwrapped_text_encoder, 470 text_encoder=unwrapped_text_encoder,
475 vae=self.vae, 471 vae=self.vae,
476 unet=unwrapped_unet, 472 unet=unwrapped_unet,
477 tokenizer=self.tokenizer, 473 tokenizer=self.tokenizer,
478 scheduler=scheduler, 474 scheduler=self.scheduler,
479 ).to(self.accelerator.device) 475 ).to(self.accelerator.device)
480 pipeline.set_progress_bar_config(dynamic_ncols=True) 476 pipeline.set_progress_bar_config(dynamic_ncols=True)
481 477
@@ -587,6 +583,9 @@ def main():
587 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') 583 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
588 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') 584 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
589 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') 585 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
586 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
587 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
588 args.pretrained_model_name_or_path, subfolder='scheduler')
590 589
591 unet.set_use_memory_efficient_attention_xformers(True) 590 unet.set_use_memory_efficient_attention_xformers(True)
592 591
@@ -690,13 +689,6 @@ def main():
690 eps=args.adam_epsilon, 689 eps=args.adam_epsilon,
691 ) 690 )
692 691
693 noise_scheduler = DDPMScheduler(
694 beta_start=0.00085,
695 beta_end=0.012,
696 beta_schedule="scaled_linear",
697 num_train_timesteps=args.noise_timesteps
698 )
699
700 weight_dtype = torch.float32 692 weight_dtype = torch.float32
701 if args.mixed_precision == "fp16": 693 if args.mixed_precision == "fp16":
702 weight_dtype = torch.float16 694 weight_dtype = torch.float16
@@ -755,16 +747,12 @@ def main():
755 for i in range(0, len(missing_data), args.sample_batch_size) 747 for i in range(0, len(missing_data), args.sample_batch_size)
756 ] 748 ]
757 749
758 scheduler = DPMSolverMultistepScheduler(
759 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
760 )
761
762 pipeline = VlpnStableDiffusion( 750 pipeline = VlpnStableDiffusion(
763 text_encoder=text_encoder, 751 text_encoder=text_encoder,
764 vae=vae, 752 vae=vae,
765 unet=unet, 753 unet=unet,
766 tokenizer=tokenizer, 754 tokenizer=tokenizer,
767 scheduler=scheduler, 755 scheduler=checkpoint_scheduler,
768 ).to(accelerator.device) 756 ).to(accelerator.device)
769 pipeline.set_progress_bar_config(dynamic_ncols=True) 757 pipeline.set_progress_bar_config(dynamic_ncols=True)
770 758
@@ -876,6 +864,7 @@ def main():
876 ema_unet=ema_unet, 864 ema_unet=ema_unet,
877 tokenizer=tokenizer, 865 tokenizer=tokenizer,
878 text_encoder=text_encoder, 866 text_encoder=text_encoder,
867 scheduler=checkpoint_scheduler,
879 output_dir=basepath, 868 output_dir=basepath,
880 instance_identifier=instance_identifier, 869 instance_identifier=instance_identifier,
881 placeholder_token=args.placeholder_token, 870 placeholder_token=args.placeholder_token,
diff --git a/textual_inversion.py b/textual_inversion.py
index bf591bc..40ddaab 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
15from accelerate import Accelerator 15from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel 18from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup 19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
20from PIL import Image 20from PIL import Image
21from tqdm.auto import tqdm 21from tqdm.auto import tqdm
@@ -364,6 +364,7 @@ class Checkpointer:
364 unet, 364 unet,
365 tokenizer, 365 tokenizer,
366 text_encoder, 366 text_encoder,
367 scheduler,
367 instance_identifier, 368 instance_identifier,
368 placeholder_token, 369 placeholder_token,
369 placeholder_token_id, 370 placeholder_token_id,
@@ -379,6 +380,7 @@ class Checkpointer:
379 self.unet = unet 380 self.unet = unet
380 self.tokenizer = tokenizer 381 self.tokenizer = tokenizer
381 self.text_encoder = text_encoder 382 self.text_encoder = text_encoder
383 self.scheduler = scheduler
382 self.instance_identifier = instance_identifier 384 self.instance_identifier = instance_identifier
383 self.placeholder_token = placeholder_token 385 self.placeholder_token = placeholder_token
384 self.placeholder_token_id = placeholder_token_id 386 self.placeholder_token_id = placeholder_token_id
@@ -413,9 +415,6 @@ class Checkpointer:
413 samples_path = Path(self.output_dir).joinpath("samples") 415 samples_path = Path(self.output_dir).joinpath("samples")
414 416
415 unwrapped = self.accelerator.unwrap_model(self.text_encoder) 417 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
416 scheduler = EulerAncestralDiscreteScheduler(
417 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
418 )
419 418
420 # Save a sample image 419 # Save a sample image
421 pipeline = VlpnStableDiffusion( 420 pipeline = VlpnStableDiffusion(
@@ -423,7 +422,7 @@ class Checkpointer:
423 vae=self.vae, 422 vae=self.vae,
424 unet=self.unet, 423 unet=self.unet,
425 tokenizer=self.tokenizer, 424 tokenizer=self.tokenizer,
426 scheduler=scheduler, 425 scheduler=self.scheduler,
427 ).to(self.accelerator.device) 426 ).to(self.accelerator.device)
428 pipeline.set_progress_bar_config(dynamic_ncols=True) 427 pipeline.set_progress_bar_config(dynamic_ncols=True)
429 428
@@ -536,8 +535,10 @@ def main():
536 # Load models and create wrapper for stable diffusion 535 # Load models and create wrapper for stable diffusion
537 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') 536 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
538 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') 537 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
539 unet = UNet2DConditionModel.from_pretrained( 538 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
540 args.pretrained_model_name_or_path, subfolder='unet') 539 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
540 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
541 args.pretrained_model_name_or_path, subfolder='scheduler')
541 542
542 prompt_processor = PromptProcessor(tokenizer, text_encoder) 543 prompt_processor = PromptProcessor(tokenizer, text_encoder)
543 544
@@ -600,13 +601,6 @@ def main():
600 eps=args.adam_epsilon, 601 eps=args.adam_epsilon,
601 ) 602 )
602 603
603 noise_scheduler = DDPMScheduler(
604 beta_start=0.00085,
605 beta_end=0.012,
606 beta_schedule="scaled_linear",
607 num_train_timesteps=args.noise_timesteps
608 )
609
610 def collate_fn(examples): 604 def collate_fn(examples):
611 prompts = [example["prompts"] for example in examples] 605 prompts = [example["prompts"] for example in examples]
612 nprompts = [example["nprompts"] for example in examples] 606 nprompts = [example["nprompts"] for example in examples]
@@ -772,6 +766,7 @@ def main():
772 unet=unet, 766 unet=unet,
773 tokenizer=tokenizer, 767 tokenizer=tokenizer,
774 text_encoder=text_encoder, 768 text_encoder=text_encoder,
769 scheduler=checkpoint_scheduler,
775 instance_identifier=args.instance_identifier, 770 instance_identifier=args.instance_identifier,
776 placeholder_token=args.placeholder_token, 771 placeholder_token=args.placeholder_token,
777 placeholder_token_id=placeholder_token_id, 772 placeholder_token_id=placeholder_token_id,