From 49463992f48ec25f2ea31b220a6cedac3466467a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 26 Oct 2022 11:11:33 +0200 Subject: New Euler_a scheduler --- dreambooth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 2c24908..a181293 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -23,7 +23,7 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from schedulers.scheduling_euler_a import EulerAScheduler +from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from models.clip.prompt import PromptProcessor @@ -443,7 +443,7 @@ class Checkpointer: self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) - scheduler = EulerAScheduler( + scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) @@ -715,7 +715,7 @@ def main(): for i in range(0, len(missing_data), args.sample_batch_size) ] - scheduler = EulerAScheduler( + scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) -- cgit v1.2.3-54-g00ecf