diff options
| -rw-r--r-- | dreambooth.py | 29 | ||||
| -rw-r--r-- | textual_inversion.py | 23 |
2 files changed, 18 insertions, 34 deletions
diff --git a/dreambooth.py b/dreambooth.py index 2b8a35e..a266d83 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -387,6 +387,7 @@ class Checkpointer: | |||
| 387 | ema_unet, | 387 | ema_unet, |
| 388 | tokenizer, | 388 | tokenizer, |
| 389 | text_encoder, | 389 | text_encoder, |
| 390 | scheduler, | ||
| 390 | output_dir: Path, | 391 | output_dir: Path, |
| 391 | instance_identifier, | 392 | instance_identifier, |
| 392 | placeholder_token, | 393 | placeholder_token, |
| @@ -403,6 +404,7 @@ class Checkpointer: | |||
| 403 | self.ema_unet = ema_unet | 404 | self.ema_unet = ema_unet |
| 404 | self.tokenizer = tokenizer | 405 | self.tokenizer = tokenizer |
| 405 | self.text_encoder = text_encoder | 406 | self.text_encoder = text_encoder |
| 407 | self.scheduler = scheduler | ||
| 406 | self.output_dir = output_dir | 408 | self.output_dir = output_dir |
| 407 | self.instance_identifier = instance_identifier | 409 | self.instance_identifier = instance_identifier |
| 408 | self.placeholder_token = placeholder_token | 410 | self.placeholder_token = placeholder_token |
| @@ -445,9 +447,7 @@ class Checkpointer: | |||
| 445 | vae=self.vae, | 447 | vae=self.vae, |
| 446 | unet=unwrapped_unet, | 448 | unet=unwrapped_unet, |
| 447 | tokenizer=self.tokenizer, | 449 | tokenizer=self.tokenizer, |
| 448 | scheduler=PNDMScheduler( | 450 | scheduler=self.scheduler, |
| 449 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | ||
| 450 | ), | ||
| 451 | ) | 451 | ) |
| 452 | pipeline.save_pretrained(self.output_dir.joinpath("model")) | 452 | pipeline.save_pretrained(self.output_dir.joinpath("model")) |
| 453 | 453 | ||
| @@ -466,16 +466,12 @@ class Checkpointer: | |||
| 466 | self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) | 466 | self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) |
| 467 | unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 467 | unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
| 468 | 468 | ||
| 469 | scheduler = DPMSolverMultistepScheduler( | ||
| 470 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 471 | ) | ||
| 472 | |||
| 473 | pipeline = VlpnStableDiffusion( | 469 | pipeline = VlpnStableDiffusion( |
| 474 | text_encoder=unwrapped_text_encoder, | 470 | text_encoder=unwrapped_text_encoder, |
| 475 | vae=self.vae, | 471 | vae=self.vae, |
| 476 | unet=unwrapped_unet, | 472 | unet=unwrapped_unet, |
| 477 | tokenizer=self.tokenizer, | 473 | tokenizer=self.tokenizer, |
| 478 | scheduler=scheduler, | 474 | scheduler=self.scheduler, |
| 479 | ).to(self.accelerator.device) | 475 | ).to(self.accelerator.device) |
| 480 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 476 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 481 | 477 | ||
| @@ -587,6 +583,9 @@ def main(): | |||
| 587 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 583 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
| 588 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') | 584 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') |
| 589 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') | 585 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') |
| 586 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') | ||
| 587 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | ||
| 588 | args.pretrained_model_name_or_path, subfolder='scheduler') | ||
| 590 | 589 | ||
| 591 | unet.set_use_memory_efficient_attention_xformers(True) | 590 | unet.set_use_memory_efficient_attention_xformers(True) |
| 592 | 591 | ||
| @@ -690,13 +689,6 @@ def main(): | |||
| 690 | eps=args.adam_epsilon, | 689 | eps=args.adam_epsilon, |
| 691 | ) | 690 | ) |
| 692 | 691 | ||
| 693 | noise_scheduler = DDPMScheduler( | ||
| 694 | beta_start=0.00085, | ||
| 695 | beta_end=0.012, | ||
| 696 | beta_schedule="scaled_linear", | ||
| 697 | num_train_timesteps=args.noise_timesteps | ||
| 698 | ) | ||
| 699 | |||
| 700 | weight_dtype = torch.float32 | 692 | weight_dtype = torch.float32 |
| 701 | if args.mixed_precision == "fp16": | 693 | if args.mixed_precision == "fp16": |
| 702 | weight_dtype = torch.float16 | 694 | weight_dtype = torch.float16 |
| @@ -755,16 +747,12 @@ def main(): | |||
| 755 | for i in range(0, len(missing_data), args.sample_batch_size) | 747 | for i in range(0, len(missing_data), args.sample_batch_size) |
| 756 | ] | 748 | ] |
| 757 | 749 | ||
| 758 | scheduler = DPMSolverMultistepScheduler( | ||
| 759 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 760 | ) | ||
| 761 | |||
| 762 | pipeline = VlpnStableDiffusion( | 750 | pipeline = VlpnStableDiffusion( |
| 763 | text_encoder=text_encoder, | 751 | text_encoder=text_encoder, |
| 764 | vae=vae, | 752 | vae=vae, |
| 765 | unet=unet, | 753 | unet=unet, |
| 766 | tokenizer=tokenizer, | 754 | tokenizer=tokenizer, |
| 767 | scheduler=scheduler, | 755 | scheduler=checkpoint_scheduler, |
| 768 | ).to(accelerator.device) | 756 | ).to(accelerator.device) |
| 769 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 757 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 770 | 758 | ||
| @@ -876,6 +864,7 @@ def main(): | |||
| 876 | ema_unet=ema_unet, | 864 | ema_unet=ema_unet, |
| 877 | tokenizer=tokenizer, | 865 | tokenizer=tokenizer, |
| 878 | text_encoder=text_encoder, | 866 | text_encoder=text_encoder, |
| 867 | scheduler=checkpoint_scheduler, | ||
| 879 | output_dir=basepath, | 868 | output_dir=basepath, |
| 880 | instance_identifier=instance_identifier, | 869 | instance_identifier=instance_identifier, |
| 881 | placeholder_token=args.placeholder_token, | 870 | placeholder_token=args.placeholder_token, |
diff --git a/textual_inversion.py b/textual_inversion.py index bf591bc..40ddaab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -15,7 +15,7 @@ import torch.utils.checkpoint | |||
| 15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
| 16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
| 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
| 20 | from PIL import Image | 20 | from PIL import Image |
| 21 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
| @@ -364,6 +364,7 @@ class Checkpointer: | |||
| 364 | unet, | 364 | unet, |
| 365 | tokenizer, | 365 | tokenizer, |
| 366 | text_encoder, | 366 | text_encoder, |
| 367 | scheduler, | ||
| 367 | instance_identifier, | 368 | instance_identifier, |
| 368 | placeholder_token, | 369 | placeholder_token, |
| 369 | placeholder_token_id, | 370 | placeholder_token_id, |
| @@ -379,6 +380,7 @@ class Checkpointer: | |||
| 379 | self.unet = unet | 380 | self.unet = unet |
| 380 | self.tokenizer = tokenizer | 381 | self.tokenizer = tokenizer |
| 381 | self.text_encoder = text_encoder | 382 | self.text_encoder = text_encoder |
| 383 | self.scheduler = scheduler | ||
| 382 | self.instance_identifier = instance_identifier | 384 | self.instance_identifier = instance_identifier |
| 383 | self.placeholder_token = placeholder_token | 385 | self.placeholder_token = placeholder_token |
| 384 | self.placeholder_token_id = placeholder_token_id | 386 | self.placeholder_token_id = placeholder_token_id |
| @@ -413,9 +415,6 @@ class Checkpointer: | |||
| 413 | samples_path = Path(self.output_dir).joinpath("samples") | 415 | samples_path = Path(self.output_dir).joinpath("samples") |
| 414 | 416 | ||
| 415 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | 417 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) |
| 416 | scheduler = EulerAncestralDiscreteScheduler( | ||
| 417 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 418 | ) | ||
| 419 | 418 | ||
| 420 | # Save a sample image | 419 | # Save a sample image |
| 421 | pipeline = VlpnStableDiffusion( | 420 | pipeline = VlpnStableDiffusion( |
| @@ -423,7 +422,7 @@ class Checkpointer: | |||
| 423 | vae=self.vae, | 422 | vae=self.vae, |
| 424 | unet=self.unet, | 423 | unet=self.unet, |
| 425 | tokenizer=self.tokenizer, | 424 | tokenizer=self.tokenizer, |
| 426 | scheduler=scheduler, | 425 | scheduler=self.scheduler, |
| 427 | ).to(self.accelerator.device) | 426 | ).to(self.accelerator.device) |
| 428 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 427 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 429 | 428 | ||
| @@ -536,8 +535,10 @@ def main(): | |||
| 536 | # Load models and create wrapper for stable diffusion | 535 | # Load models and create wrapper for stable diffusion |
| 537 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 536 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
| 538 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') | 537 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') |
| 539 | unet = UNet2DConditionModel.from_pretrained( | 538 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') |
| 540 | args.pretrained_model_name_or_path, subfolder='unet') | 539 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') |
| 540 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | ||
| 541 | args.pretrained_model_name_or_path, subfolder='scheduler') | ||
| 541 | 542 | ||
| 542 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 543 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
| 543 | 544 | ||
| @@ -600,13 +601,6 @@ def main(): | |||
| 600 | eps=args.adam_epsilon, | 601 | eps=args.adam_epsilon, |
| 601 | ) | 602 | ) |
| 602 | 603 | ||
| 603 | noise_scheduler = DDPMScheduler( | ||
| 604 | beta_start=0.00085, | ||
| 605 | beta_end=0.012, | ||
| 606 | beta_schedule="scaled_linear", | ||
| 607 | num_train_timesteps=args.noise_timesteps | ||
| 608 | ) | ||
| 609 | |||
| 610 | def collate_fn(examples): | 604 | def collate_fn(examples): |
| 611 | prompts = [example["prompts"] for example in examples] | 605 | prompts = [example["prompts"] for example in examples] |
| 612 | nprompts = [example["nprompts"] for example in examples] | 606 | nprompts = [example["nprompts"] for example in examples] |
| @@ -772,6 +766,7 @@ def main(): | |||
| 772 | unet=unet, | 766 | unet=unet, |
| 773 | tokenizer=tokenizer, | 767 | tokenizer=tokenizer, |
| 774 | text_encoder=text_encoder, | 768 | text_encoder=text_encoder, |
| 769 | scheduler=checkpoint_scheduler, | ||
| 775 | instance_identifier=args.instance_identifier, | 770 | instance_identifier=args.instance_identifier, |
| 776 | placeholder_token=args.placeholder_token, | 771 | placeholder_token=args.placeholder_token, |
| 777 | placeholder_token_id=placeholder_token_id, | 772 | placeholder_token_id=placeholder_token_id, |
