diff options
author | Volpeon <git@volpeon.ink> | 2022-10-03 17:38:44 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-03 17:38:44 +0200 |
commit | f23fd5184b8ba4ec04506495f4a61726e50756f7 (patch) | |
tree | d4c5666b291316ed95437cc1c917b03ef3b679da /infer.py | |
parent | Added negative prompt support for training scripts (diff) | |
download | textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.gz textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.bz2 textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.zip |
Small perf improvements
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 5 |
1 files changed, 3 insertions, 2 deletions
@@ -16,6 +16,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
16 | from schedulers.scheduling_euler_a import EulerAScheduler | 16 | from schedulers.scheduling_euler_a import EulerAScheduler |
17 | 17 | ||
18 | 18 | ||
19 | torch.backends.cuda.matmul.allow_tf32 = True | ||
20 | |||
21 | |||
19 | default_args = { | 22 | default_args = { |
20 | "model": None, | 23 | "model": None, |
21 | "scheduler": "euler_a", | 24 | "scheduler": "euler_a", |
@@ -166,7 +169,6 @@ def create_pipeline(model, scheduler, dtype): | |||
166 | text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) | 169 | text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) |
167 | vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) | 170 | vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) |
168 | unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) | 171 | unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) |
169 | feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype) | ||
170 | 172 | ||
171 | if scheduler == "plms": | 173 | if scheduler == "plms": |
172 | scheduler = PNDMScheduler( | 174 | scheduler = PNDMScheduler( |
@@ -191,7 +193,6 @@ def create_pipeline(model, scheduler, dtype): | |||
191 | unet=unet, | 193 | unet=unet, |
192 | tokenizer=tokenizer, | 194 | tokenizer=tokenizer, |
193 | scheduler=scheduler, | 195 | scheduler=scheduler, |
194 | feature_extractor=feature_extractor | ||
195 | ) | 196 | ) |
196 | # pipeline.enable_attention_slicing() | 197 | # pipeline.enable_attention_slicing() |
197 | pipeline.to("cuda") | 198 | pipeline.to("cuda") |