-rw-r--r--  dreambooth.py        | 29
-rw-r--r--  textual_inversion.py | 23
2 files changed, 18 insertions, 34 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 2b8a35e..a266d83 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -387,6 +387,7 @@ class Checkpointer:
         ema_unet,
         tokenizer,
         text_encoder,
+        scheduler,
         output_dir: Path,
         instance_identifier,
         placeholder_token,
@@ -403,6 +404,7 @@ class Checkpointer:
         self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.scheduler = scheduler
         self.output_dir = output_dir
         self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
@@ -445,9 +447,7 @@ class Checkpointer:
             vae=self.vae,
             unet=unwrapped_unet,
             tokenizer=self.tokenizer,
-            scheduler=PNDMScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-            ),
+            scheduler=self.scheduler,
         )
         pipeline.save_pretrained(self.output_dir.joinpath("model"))
 
@@ -466,16 +466,12 @@ class Checkpointer:
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-
         pipeline = VlpnStableDiffusion(
             text_encoder=unwrapped_text_encoder,
             vae=self.vae,
             unet=unwrapped_unet,
             tokenizer=self.tokenizer,
-            scheduler=scheduler,
+            scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
@@ -587,6 +583,9 @@ def main():
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder='scheduler')
 
     unet.set_use_memory_efficient_attention_xformers(True)
 
@@ -690,13 +689,6 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear",
-        num_train_timesteps=args.noise_timesteps
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -755,16 +747,12 @@ def main():
             for i in range(0, len(missing_data), args.sample_batch_size)
         ]
 
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=vae,
             unet=unet,
             tokenizer=tokenizer,
-            scheduler=scheduler,
+            scheduler=checkpoint_scheduler,
         ).to(accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
@@ -876,6 +864,7 @@ def main():
         ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        scheduler=checkpoint_scheduler,
         output_dir=basepath,
         instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
diff --git a/textual_inversion.py b/textual_inversion.py
index bf591bc..40ddaab 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from PIL import Image
 from tqdm.auto import tqdm
@@ -364,6 +364,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        scheduler,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
@@ -379,6 +380,7 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.scheduler = scheduler
         self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
@@ -413,9 +415,6 @@ class Checkpointer:
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
 
         # Save a sample image
         pipeline = VlpnStableDiffusion(
@@ -423,7 +422,7 @@ class Checkpointer:
             vae=self.vae,
             unet=self.unet,
             tokenizer=self.tokenizer,
-            scheduler=scheduler,
+            scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
@@ -536,8 +535,10 @@ def main():
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='unet')
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder='scheduler')
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -600,13 +601,6 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear",
-        num_train_timesteps=args.noise_timesteps
-    )
-
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -772,6 +766,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        scheduler=checkpoint_scheduler,
         instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
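For reference, a minimal sketch of the pattern both files converge on, assuming only the diffusers calls that already appear in the hunks above (DDPMScheduler.from_pretrained and DPMSolverMultistepScheduler.from_pretrained with subfolder='scheduler'); the model path below is a hypothetical stand-in for args.pretrained_model_name_or_path, and VlpnStableDiffusion is this repository's own pipeline class.

# Sketch: load both schedulers once from the checkpoint's scheduler config
# instead of re-creating them inline with hard-coded betas.
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler

pretrained_path = "runwayml/stable-diffusion-v1-5"  # hypothetical; stands in for args.pretrained_model_name_or_path

# Training keeps using the DDPM scheduler for adding noise ...
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
# ... while checkpointing/sampling reuses a single DPMSolverMultistep scheduler,
# which is passed to the Checkpointer and on to VlpnStableDiffusion.
checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_path, subfolder="scheduler")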