summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-03 17:38:44 +0200
committerVolpeon <git@volpeon.ink>2022-10-03 17:38:44 +0200
commitf23fd5184b8ba4ec04506495f4a61726e50756f7 (patch)
treed4c5666b291316ed95437cc1c917b03ef3b679da /infer.py
parentAdded negative prompt support for training scripts (diff)
downloadtextual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.gz
textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.bz2
textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.zip
Small perf improvements
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/infer.py b/infer.py
index b15b17f..3dc0f32 100644
--- a/infer.py
+++ b/infer.py
@@ -16,6 +16,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
16from schedulers.scheduling_euler_a import EulerAScheduler 16from schedulers.scheduling_euler_a import EulerAScheduler
17 17
18 18
19torch.backends.cuda.matmul.allow_tf32 = True
20
21
19default_args = { 22default_args = {
20 "model": None, 23 "model": None,
21 "scheduler": "euler_a", 24 "scheduler": "euler_a",
@@ -166,7 +169,6 @@ def create_pipeline(model, scheduler, dtype):
166 text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) 169 text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
167 vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) 170 vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
168 unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) 171 unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
169 feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)
170 172
171 if scheduler == "plms": 173 if scheduler == "plms":
172 scheduler = PNDMScheduler( 174 scheduler = PNDMScheduler(
@@ -191,7 +193,6 @@ def create_pipeline(model, scheduler, dtype):
191 unet=unet, 193 unet=unet,
192 tokenizer=tokenizer, 194 tokenizer=tokenizer,
193 scheduler=scheduler, 195 scheduler=scheduler,
194 feature_extractor=feature_extractor
195 ) 196 )
196 # pipeline.enable_attention_slicing() 197 # pipeline.enable_attention_slicing()
197 pipeline.to("cuda") 198 pipeline.to("cuda")