From 4523f853fe47592db30ab3e03e89fb917db68464 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 27 Nov 2022 17:19:46 +0100 Subject: Generic loading of scheduler (training) --- textual_inversion.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index bf591bc..40ddaab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -15,7 +15,7 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup from PIL import Image from tqdm.auto import tqdm @@ -364,6 +364,7 @@ class Checkpointer: unet, tokenizer, text_encoder, + scheduler, instance_identifier, placeholder_token, placeholder_token_id, @@ -379,6 +380,7 @@ class Checkpointer: self.unet = unet self.tokenizer = tokenizer self.text_encoder = text_encoder + self.scheduler = scheduler self.instance_identifier = instance_identifier self.placeholder_token = placeholder_token self.placeholder_token_id = placeholder_token_id @@ -413,9 +415,6 @@ class Checkpointer: samples_path = Path(self.output_dir).joinpath("samples") unwrapped = self.accelerator.unwrap_model(self.text_encoder) - scheduler = EulerAncestralDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) # Save a sample image pipeline = VlpnStableDiffusion( @@ -423,7 +422,7 @@ class Checkpointer: vae=self.vae, unet=self.unet, tokenizer=self.tokenizer, - scheduler=scheduler, + scheduler=self.scheduler, ).to(self.accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) @@ -536,8 +535,10 @@ def main(): # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder='unet') + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') + checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder='scheduler') prompt_processor = PromptProcessor(tokenizer, text_encoder) @@ -600,13 +601,6 @@ def main(): eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=args.noise_timesteps - ) - def collate_fn(examples): prompts = [example["prompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] @@ -772,6 +766,7 @@ def main(): unet=unet, tokenizer=tokenizer, text_encoder=text_encoder, + scheduler=checkpoint_scheduler, instance_identifier=args.instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, -- cgit v1.2.3-54-g00ecf