summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-03 17:38:44 +0200
committerVolpeon <git@volpeon.ink>2022-10-03 17:38:44 +0200
commitf23fd5184b8ba4ec04506495f4a61726e50756f7 (patch)
treed4c5666b291316ed95437cc1c917b03ef3b679da /textual_inversion.py
parentAdded negative prompt support for training scripts (diff)
downloadtextual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.gz
textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.bz2
textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.zip
Small perf improvements
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 00d460f..5fc2338 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -14,7 +14,7 @@ import torch.utils.checkpoint
14from accelerate import Accelerator 14from accelerate import Accelerator
15from accelerate.logging import get_logger 15from accelerate.logging import get_logger
16from accelerate.utils import LoggerType, set_seed 16from accelerate.utils import LoggerType, set_seed
17from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel 17from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
18from schedulers.scheduling_euler_a import EulerAScheduler 18from schedulers.scheduling_euler_a import EulerAScheduler
19from diffusers.optimization import get_scheduler 19from diffusers.optimization import get_scheduler
20from PIL import Image 20from PIL import Image
@@ -30,6 +30,9 @@ from data.textual_inversion.csv import CSVDataModule
30logger = get_logger(__name__) 30logger = get_logger(__name__)
31 31
32 32
33torch.backends.cuda.matmul.allow_tf32 = True
34
35
33def parse_args(): 36def parse_args():
34 parser = argparse.ArgumentParser( 37 parser = argparse.ArgumentParser(
35 description="Simple example of a training script." 38 description="Simple example of a training script."
@@ -370,7 +373,6 @@ class Checkpointer:
370 unet=self.unet, 373 unet=self.unet,
371 tokenizer=self.tokenizer, 374 tokenizer=self.tokenizer,
372 scheduler=scheduler, 375 scheduler=scheduler,
373 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
374 ).to(self.accelerator.device) 376 ).to(self.accelerator.device)
375 pipeline.enable_attention_slicing() 377 pipeline.enable_attention_slicing()
376 378