From 4523f853fe47592db30ab3e03e89fb917db68464 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 27 Nov 2022 17:19:46 +0100
Subject: Generic loading of scheduler (training)

---
 dreambooth.py | 29 +++++++++-------------------
 1 file changed, 9 insertions(+), 20 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 2b8a35e..a266d83 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -387,6 +387,7 @@ class Checkpointer:
         ema_unet,
         tokenizer,
         text_encoder,
+        scheduler,
         output_dir: Path,
         instance_identifier,
         placeholder_token,
@@ -403,6 +404,7 @@ class Checkpointer:
         self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.scheduler = scheduler
         self.output_dir = output_dir
         self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
@@ -445,9 +447,7 @@ class Checkpointer:
             vae=self.vae,
             unet=unwrapped_unet,
             tokenizer=self.tokenizer,
-            scheduler=PNDMScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-            ),
+            scheduler=self.scheduler,
         )
         pipeline.save_pretrained(self.output_dir.joinpath("model"))
 
@@ -466,16 +466,12 @@ class Checkpointer:
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-
         pipeline = VlpnStableDiffusion(
             text_encoder=unwrapped_text_encoder,
             vae=self.vae,
             unet=unwrapped_unet,
             tokenizer=self.tokenizer,
-            scheduler=scheduler,
+            scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
@@ -587,6 +583,9 @@ def main():
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder='scheduler')
 
     unet.set_use_memory_efficient_attention_xformers(True)
 
@@ -690,13 +689,6 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear",
-        num_train_timesteps=args.noise_timesteps
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -755,16 +747,12 @@ def main():
         for i in range(0, len(missing_data), args.sample_batch_size)
     ]
 
-    scheduler = DPMSolverMultistepScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-    )
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=checkpoint_scheduler,
     ).to(accelerator.device)
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
@@ -876,6 +864,7 @@ def main():
         ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        scheduler=checkpoint_scheduler,
         output_dir=basepath,
         instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
--
cgit v1.2.3-54-g00ecf
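
For reference, a minimal sketch of the pattern this commit moves to (not part of the patch itself): instead of duplicating the scheduler hyperparameters (beta_start, beta_end, beta_schedule, timesteps) at every construction site, both schedulers are instantiated from the scheduler config shipped with the pretrained model. The model path below is a hypothetical placeholder standing in for args.pretrained_model_name_or_path.

    from diffusers import DDPMScheduler, DPMSolverMultistepScheduler

    model_path = 'path/to/pretrained-model'  # hypothetical placeholder

    # Both schedulers are built from the same saved config in the 'scheduler'
    # subfolder, so the beta/timestep settings stay consistent with the
    # checkpoint instead of being hard-coded in the training script.
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder='scheduler')
    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder='scheduler')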