diff options
| author | Volpeon <git@volpeon.ink> | 2022-11-27 17:19:46 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-11-27 17:19:46 +0100 |
| commit | 4523f853fe47592db30ab3e03e89fb917db68464 (patch) | |
| tree | cbe8260d25933e14f637c97053d498d0f353eb60 /textual_inversion.py | |
| parent | Make prompt processor compatible with any model (diff) | |
| download | textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.tar.gz textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.tar.bz2 textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.zip | |
Generic loading of scheduler (training)
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 23 |
1 files changed, 9 insertions, 14 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index bf591bc..40ddaab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -15,7 +15,7 @@ import torch.utils.checkpoint | |||
| 15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
| 16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
| 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
| 20 | from PIL import Image | 20 | from PIL import Image |
| 21 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
| @@ -364,6 +364,7 @@ class Checkpointer: | |||
| 364 | unet, | 364 | unet, |
| 365 | tokenizer, | 365 | tokenizer, |
| 366 | text_encoder, | 366 | text_encoder, |
| 367 | scheduler, | ||
| 367 | instance_identifier, | 368 | instance_identifier, |
| 368 | placeholder_token, | 369 | placeholder_token, |
| 369 | placeholder_token_id, | 370 | placeholder_token_id, |
| @@ -379,6 +380,7 @@ class Checkpointer: | |||
| 379 | self.unet = unet | 380 | self.unet = unet |
| 380 | self.tokenizer = tokenizer | 381 | self.tokenizer = tokenizer |
| 381 | self.text_encoder = text_encoder | 382 | self.text_encoder = text_encoder |
| 383 | self.scheduler = scheduler | ||
| 382 | self.instance_identifier = instance_identifier | 384 | self.instance_identifier = instance_identifier |
| 383 | self.placeholder_token = placeholder_token | 385 | self.placeholder_token = placeholder_token |
| 384 | self.placeholder_token_id = placeholder_token_id | 386 | self.placeholder_token_id = placeholder_token_id |
| @@ -413,9 +415,6 @@ class Checkpointer: | |||
| 413 | samples_path = Path(self.output_dir).joinpath("samples") | 415 | samples_path = Path(self.output_dir).joinpath("samples") |
| 414 | 416 | ||
| 415 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | 417 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) |
| 416 | scheduler = EulerAncestralDiscreteScheduler( | ||
| 417 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 418 | ) | ||
| 419 | 418 | ||
| 420 | # Save a sample image | 419 | # Save a sample image |
| 421 | pipeline = VlpnStableDiffusion( | 420 | pipeline = VlpnStableDiffusion( |
| @@ -423,7 +422,7 @@ class Checkpointer: | |||
| 423 | vae=self.vae, | 422 | vae=self.vae, |
| 424 | unet=self.unet, | 423 | unet=self.unet, |
| 425 | tokenizer=self.tokenizer, | 424 | tokenizer=self.tokenizer, |
| 426 | scheduler=scheduler, | 425 | scheduler=self.scheduler, |
| 427 | ).to(self.accelerator.device) | 426 | ).to(self.accelerator.device) |
| 428 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 427 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 429 | 428 | ||
| @@ -536,8 +535,10 @@ def main(): | |||
| 536 | # Load models and create wrapper for stable diffusion | 535 | # Load models and create wrapper for stable diffusion |
| 537 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 536 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
| 538 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') | 537 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') |
| 539 | unet = UNet2DConditionModel.from_pretrained( | 538 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') |
| 540 | args.pretrained_model_name_or_path, subfolder='unet') | 539 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') |
| 540 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | ||
| 541 | args.pretrained_model_name_or_path, subfolder='scheduler') | ||
| 541 | 542 | ||
| 542 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 543 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
| 543 | 544 | ||
| @@ -600,13 +601,6 @@ def main(): | |||
| 600 | eps=args.adam_epsilon, | 601 | eps=args.adam_epsilon, |
| 601 | ) | 602 | ) |
| 602 | 603 | ||
| 603 | noise_scheduler = DDPMScheduler( | ||
| 604 | beta_start=0.00085, | ||
| 605 | beta_end=0.012, | ||
| 606 | beta_schedule="scaled_linear", | ||
| 607 | num_train_timesteps=args.noise_timesteps | ||
| 608 | ) | ||
| 609 | |||
| 610 | def collate_fn(examples): | 604 | def collate_fn(examples): |
| 611 | prompts = [example["prompts"] for example in examples] | 605 | prompts = [example["prompts"] for example in examples] |
| 612 | nprompts = [example["nprompts"] for example in examples] | 606 | nprompts = [example["nprompts"] for example in examples] |
| @@ -772,6 +766,7 @@ def main(): | |||
| 772 | unet=unet, | 766 | unet=unet, |
| 773 | tokenizer=tokenizer, | 767 | tokenizer=tokenizer, |
| 774 | text_encoder=text_encoder, | 768 | text_encoder=text_encoder, |
| 769 | scheduler=checkpoint_scheduler, | ||
| 775 | instance_identifier=args.instance_identifier, | 770 | instance_identifier=args.instance_identifier, |
| 776 | placeholder_token=args.placeholder_token, | 771 | placeholder_token=args.placeholder_token, |
| 777 | placeholder_token_id=placeholder_token_id, | 772 | placeholder_token_id=placeholder_token_id, |
