summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 2c24908..a181293 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -23,7 +23,7 @@ from tqdm.auto import tqdm
23from transformers import CLIPTextModel, CLIPTokenizer 23from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify 24from slugify import slugify
25 25
26from schedulers.scheduling_euler_a import EulerAScheduler 26from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
28from data.csv import CSVDataModule 28from data.csv import CSVDataModule
29from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
@@ -443,7 +443,7 @@ class Checkpointer:
443 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) 443 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
444 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) 444 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
445 445
446 scheduler = EulerAScheduler( 446 scheduler = EulerAncestralDiscreteScheduler(
447 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 447 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
448 ) 448 )
449 449
@@ -715,7 +715,7 @@ def main():
715 for i in range(0, len(missing_data), args.sample_batch_size) 715 for i in range(0, len(missing_data), args.sample_batch_size)
716 ] 716 ]
717 717
718 scheduler = EulerAScheduler( 718 scheduler = EulerAncestralDiscreteScheduler(
719 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 719 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
720 ) 720 )
721 721