summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-26 11:11:33 +0200
committerVolpeon <git@volpeon.ink>2022-10-26 11:11:33 +0200
commit49463992f48ec25f2ea31b220a6cedac3466467a (patch)
treea58f40e558c14403dbeda687708ef334371694b8 /textual_inversion.py
parentAdvanced datasets (diff)
downloadtextual-inversion-diff-49463992f48ec25f2ea31b220a6cedac3466467a.tar.gz
textual-inversion-diff-49463992f48ec25f2ea31b220a6cedac3466467a.tar.bz2
textual-inversion-diff-49463992f48ec25f2ea31b220a6cedac3466467a.zip
New Euler_a scheduler
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index bcdfd3a..dd7c3bd 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -22,7 +22,7 @@ from tqdm.auto import tqdm
22from transformers import CLIPTextModel, CLIPTokenizer 22from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 23from slugify import slugify
24 24
25from schedulers.scheduling_euler_a import EulerAScheduler 25from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule 27from data.csv import CSVDataModule
28from models.clip.prompt import PromptProcessor 28from models.clip.prompt import PromptProcessor
@@ -398,7 +398,7 @@ class Checkpointer:
398 samples_path = Path(self.output_dir).joinpath("samples") 398 samples_path = Path(self.output_dir).joinpath("samples")
399 399
400 unwrapped = self.accelerator.unwrap_model(self.text_encoder) 400 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
401 scheduler = EulerAScheduler( 401 scheduler = EulerAncestralDiscreteScheduler(
402 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 402 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
403 ) 403 )
404 404
@@ -639,7 +639,7 @@ def main():
639 batched_data = [missing_data[i:i+args.sample_batch_size] 639 batched_data = [missing_data[i:i+args.sample_batch_size]
640 for i in range(0, len(missing_data), args.sample_batch_size)] 640 for i in range(0, len(missing_data), args.sample_batch_size)]
641 641
642 scheduler = EulerAScheduler( 642 scheduler = EulerAncestralDiscreteScheduler(
643 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 643 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
644 ) 644 )
645 645