diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 23 |
1 file changed, 9 insertions, 14 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index bf591bc..40ddaab 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
20 | from PIL import Image | 20 | from PIL import Image |
21 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
@@ -364,6 +364,7 @@ class Checkpointer:
364 | unet, | 364 | unet, |
365 | tokenizer, | 365 | tokenizer, |
366 | text_encoder, | 366 | text_encoder, |
367 | scheduler, | ||
367 | instance_identifier, | 368 | instance_identifier, |
368 | placeholder_token, | 369 | placeholder_token, |
369 | placeholder_token_id, | 370 | placeholder_token_id, |
@@ -379,6 +380,7 @@ class Checkpointer:
379 | self.unet = unet | 380 | self.unet = unet |
380 | self.tokenizer = tokenizer | 381 | self.tokenizer = tokenizer |
381 | self.text_encoder = text_encoder | 382 | self.text_encoder = text_encoder |
383 | self.scheduler = scheduler | ||
382 | self.instance_identifier = instance_identifier | 384 | self.instance_identifier = instance_identifier |
383 | self.placeholder_token = placeholder_token | 385 | self.placeholder_token = placeholder_token |
384 | self.placeholder_token_id = placeholder_token_id | 386 | self.placeholder_token_id = placeholder_token_id |
@@ -413,9 +415,6 @@ class Checkpointer:
413 | samples_path = Path(self.output_dir).joinpath("samples") | 415 | samples_path = Path(self.output_dir).joinpath("samples") |
414 | 416 | ||
415 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | 417 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) |
416 | scheduler = EulerAncestralDiscreteScheduler( | ||
417 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
418 | ) | ||
419 | 418 | ||
420 | # Save a sample image | 419 | # Save a sample image |
421 | pipeline = VlpnStableDiffusion( | 420 | pipeline = VlpnStableDiffusion( |
@@ -423,7 +422,7 @@
423 | vae=self.vae, | 422 | vae=self.vae, |
424 | unet=self.unet, | 423 | unet=self.unet, |
425 | tokenizer=self.tokenizer, | 424 | tokenizer=self.tokenizer, |
426 | scheduler=scheduler, | 425 | scheduler=self.scheduler, |
427 | ).to(self.accelerator.device) | 426 | ).to(self.accelerator.device) |
428 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 427 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
429 | 428 | ||
@@ -536,8 +535,10 @@ def main():
536 | # Load models and create wrapper for stable diffusion | 535 | # Load models and create wrapper for stable diffusion |
537 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 536 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
538 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') | 537 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') |
539 | unet = UNet2DConditionModel.from_pretrained( | 538 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') |
540 | args.pretrained_model_name_or_path, subfolder='unet') | 539 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') |
540 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | ||
541 | args.pretrained_model_name_or_path, subfolder='scheduler') | ||
541 | 542 | ||
542 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 543 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
543 | 544 | ||
@@ -600,13 +601,6 @@ def main():
600 | eps=args.adam_epsilon, | 601 | eps=args.adam_epsilon, |
601 | ) | 602 | ) |
602 | 603 | ||
603 | noise_scheduler = DDPMScheduler( | ||
604 | beta_start=0.00085, | ||
605 | beta_end=0.012, | ||
606 | beta_schedule="scaled_linear", | ||
607 | num_train_timesteps=args.noise_timesteps | ||
608 | ) | ||
609 | |||
610 | def collate_fn(examples): | 604 | def collate_fn(examples): |
611 | prompts = [example["prompts"] for example in examples] | 605 | prompts = [example["prompts"] for example in examples] |
612 | nprompts = [example["nprompts"] for example in examples] | 606 | nprompts = [example["nprompts"] for example in examples] |
@@ -772,6 +766,7 @@ def main():
772 | unet=unet, | 766 | unet=unet, |
773 | tokenizer=tokenizer, | 767 | tokenizer=tokenizer, |
774 | text_encoder=text_encoder, | 768 | text_encoder=text_encoder, |
769 | scheduler=checkpoint_scheduler, | ||
775 | instance_identifier=args.instance_identifier, | 770 | instance_identifier=args.instance_identifier, |
776 | placeholder_token=args.placeholder_token, | 771 | placeholder_token=args.placeholder_token, |
777 | placeholder_token_id=placeholder_token_id, | 772 | placeholder_token_id=placeholder_token_id, |