From 4523f853fe47592db30ab3e03e89fb917db68464 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 27 Nov 2022 17:19:46 +0100 Subject: Generic loading of scheduler (training) --- dreambooth.py | 29 +++++++++-------------------- textual_inversion.py | 23 +++++++++-------------- 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index 2b8a35e..a266d83 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -387,6 +387,7 @@ class Checkpointer: ema_unet, tokenizer, text_encoder, + scheduler, output_dir: Path, instance_identifier, placeholder_token, @@ -403,6 +404,7 @@ class Checkpointer: self.ema_unet = ema_unet self.tokenizer = tokenizer self.text_encoder = text_encoder + self.scheduler = scheduler self.output_dir = output_dir self.instance_identifier = instance_identifier self.placeholder_token = placeholder_token @@ -445,9 +447,7 @@ class Checkpointer: vae=self.vae, unet=unwrapped_unet, tokenizer=self.tokenizer, - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ), + scheduler=self.scheduler, ) pipeline.save_pretrained(self.output_dir.joinpath("model")) @@ -466,16 +466,12 @@ class Checkpointer: self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) - scheduler = DPMSolverMultistepScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) - pipeline = VlpnStableDiffusion( text_encoder=unwrapped_text_encoder, vae=self.vae, unet=unwrapped_unet, tokenizer=self.tokenizer, - scheduler=scheduler, + scheduler=self.scheduler, ).to(self.accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) @@ -587,6 +583,9 @@ def main(): text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') + checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder='scheduler') unet.set_use_memory_efficient_attention_xformers(True) @@ -690,13 +689,6 @@ def main(): eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=args.noise_timesteps - ) - weight_dtype = torch.float32 if args.mixed_precision == "fp16": weight_dtype = torch.float16 @@ -755,16 +747,12 @@ def main(): for i in range(0, len(missing_data), args.sample_batch_size) ] - scheduler = DPMSolverMultistepScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) - pipeline = VlpnStableDiffusion( text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, + scheduler=checkpoint_scheduler, ).to(accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) @@ -876,6 +864,7 @@ def main(): ema_unet=ema_unet, tokenizer=tokenizer, text_encoder=text_encoder, + scheduler=checkpoint_scheduler, output_dir=basepath, instance_identifier=instance_identifier, placeholder_token=args.placeholder_token, diff --git a/textual_inversion.py b/textual_inversion.py index bf591bc..40ddaab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -15,7 +15,7 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup from PIL import Image from tqdm.auto import tqdm @@ -364,6 +364,7 @@ class Checkpointer: unet, tokenizer, text_encoder, + scheduler, instance_identifier, placeholder_token, placeholder_token_id, @@ -379,6 +380,7 @@ class Checkpointer: self.unet = unet self.tokenizer = tokenizer self.text_encoder = text_encoder + self.scheduler = scheduler self.instance_identifier = instance_identifier self.placeholder_token = placeholder_token self.placeholder_token_id = placeholder_token_id @@ -413,9 +415,6 @@ class Checkpointer: samples_path = Path(self.output_dir).joinpath("samples") unwrapped = self.accelerator.unwrap_model(self.text_encoder) - scheduler = EulerAncestralDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) # Save a sample image pipeline = VlpnStableDiffusion( @@ -423,7 +422,7 @@ class Checkpointer: vae=self.vae, unet=self.unet, tokenizer=self.tokenizer, - scheduler=scheduler, + scheduler=self.scheduler, ).to(self.accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) @@ -536,8 +535,10 @@ def main(): # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder='unet') + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') + checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder='scheduler') prompt_processor = PromptProcessor(tokenizer, text_encoder) @@ -600,13 +601,6 @@ def main(): eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=args.noise_timesteps - ) - def collate_fn(examples): prompts = [example["prompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] @@ -772,6 +766,7 @@ def main(): unet=unet, tokenizer=tokenizer, text_encoder=text_encoder, + scheduler=checkpoint_scheduler, instance_identifier=args.instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, -- cgit v1.2.3-54-g00ecf