summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-27 17:19:46 +0100
committerVolpeon <git@volpeon.ink>2022-11-27 17:19:46 +0100
commit4523f853fe47592db30ab3e03e89fb917db68464 (patch)
treecbe8260d25933e14f637c97053d498d0f353eb60 /textual_inversion.py
parentMake prompt processor compatible with any model (diff)
downloadtextual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.tar.gz
textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.tar.bz2
textual-inversion-diff-4523f853fe47592db30ab3e03e89fb917db68464.zip
Generic loading of scheduler (training)
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py23
1 file changed, 9 insertions, 14 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index bf591bc..40ddaab 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
15from accelerate import Accelerator 15from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel 18from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup 19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
20from PIL import Image 20from PIL import Image
21from tqdm.auto import tqdm 21from tqdm.auto import tqdm
@@ -364,6 +364,7 @@ class Checkpointer:
364 unet, 364 unet,
365 tokenizer, 365 tokenizer,
366 text_encoder, 366 text_encoder,
367 scheduler,
367 instance_identifier, 368 instance_identifier,
368 placeholder_token, 369 placeholder_token,
369 placeholder_token_id, 370 placeholder_token_id,
@@ -379,6 +380,7 @@ class Checkpointer:
379 self.unet = unet 380 self.unet = unet
380 self.tokenizer = tokenizer 381 self.tokenizer = tokenizer
381 self.text_encoder = text_encoder 382 self.text_encoder = text_encoder
383 self.scheduler = scheduler
382 self.instance_identifier = instance_identifier 384 self.instance_identifier = instance_identifier
383 self.placeholder_token = placeholder_token 385 self.placeholder_token = placeholder_token
384 self.placeholder_token_id = placeholder_token_id 386 self.placeholder_token_id = placeholder_token_id
@@ -413,9 +415,6 @@ class Checkpointer:
413 samples_path = Path(self.output_dir).joinpath("samples") 415 samples_path = Path(self.output_dir).joinpath("samples")
414 416
415 unwrapped = self.accelerator.unwrap_model(self.text_encoder) 417 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
416 scheduler = EulerAncestralDiscreteScheduler(
417 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
418 )
419 418
420 # Save a sample image 419 # Save a sample image
421 pipeline = VlpnStableDiffusion( 420 pipeline = VlpnStableDiffusion(
@@ -423,7 +422,7 @@ class Checkpointer:
423 vae=self.vae, 422 vae=self.vae,
424 unet=self.unet, 423 unet=self.unet,
425 tokenizer=self.tokenizer, 424 tokenizer=self.tokenizer,
426 scheduler=scheduler, 425 scheduler=self.scheduler,
427 ).to(self.accelerator.device) 426 ).to(self.accelerator.device)
428 pipeline.set_progress_bar_config(dynamic_ncols=True) 427 pipeline.set_progress_bar_config(dynamic_ncols=True)
429 428
@@ -536,8 +535,10 @@ def main():
536 # Load models and create wrapper for stable diffusion 535 # Load models and create wrapper for stable diffusion
537 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') 536 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
538 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') 537 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
539 unet = UNet2DConditionModel.from_pretrained( 538 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
540 args.pretrained_model_name_or_path, subfolder='unet') 539 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
540 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
541 args.pretrained_model_name_or_path, subfolder='scheduler')
541 542
542 prompt_processor = PromptProcessor(tokenizer, text_encoder) 543 prompt_processor = PromptProcessor(tokenizer, text_encoder)
543 544
@@ -600,13 +601,6 @@ def main():
600 eps=args.adam_epsilon, 601 eps=args.adam_epsilon,
601 ) 602 )
602 603
603 noise_scheduler = DDPMScheduler(
604 beta_start=0.00085,
605 beta_end=0.012,
606 beta_schedule="scaled_linear",
607 num_train_timesteps=args.noise_timesteps
608 )
609
610 def collate_fn(examples): 604 def collate_fn(examples):
611 prompts = [example["prompts"] for example in examples] 605 prompts = [example["prompts"] for example in examples]
612 nprompts = [example["nprompts"] for example in examples] 606 nprompts = [example["nprompts"] for example in examples]
@@ -772,6 +766,7 @@ def main():
772 unet=unet, 766 unet=unet,
773 tokenizer=tokenizer, 767 tokenizer=tokenizer,
774 text_encoder=text_encoder, 768 text_encoder=text_encoder,
769 scheduler=checkpoint_scheduler,
775 instance_identifier=args.instance_identifier, 770 instance_identifier=args.instance_identifier,
776 placeholder_token=args.placeholder_token, 771 placeholder_token=args.placeholder_token,
777 placeholder_token_id=placeholder_token_id, 772 placeholder_token_id=placeholder_token_id,