diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-03 17:38:44 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-03 17:38:44 +0200 |
| commit | f23fd5184b8ba4ec04506495f4a61726e50756f7 (patch) | |
| tree | d4c5666b291316ed95437cc1c917b03ef3b679da /infer.py | |
| parent | Added negative prompt support for training scripts (diff) | |
| download | textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.gz textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.tar.bz2 textual-inversion-diff-f23fd5184b8ba4ec04506495f4a61726e50756f7.zip | |
Small perf improvements
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 5 |
1 file changed, 3 insertions(+), 2 deletions(-)
| @@ -16,6 +16,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
| 16 | from schedulers.scheduling_euler_a import EulerAScheduler | 16 | from schedulers.scheduling_euler_a import EulerAScheduler |
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | torch.backends.cuda.matmul.allow_tf32 = True | ||
| 20 | |||
| 21 | |||
| 19 | default_args = { | 22 | default_args = { |
| 20 | "model": None, | 23 | "model": None, |
| 21 | "scheduler": "euler_a", | 24 | "scheduler": "euler_a", |
| @@ -166,7 +169,6 @@ def create_pipeline(model, scheduler, dtype): | |||
| 166 | text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) | 169 | text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) |
| 167 | vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) | 170 | vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) |
| 168 | unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) | 171 | unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) |
| 169 | feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype) | ||
| 170 | 172 | ||
| 171 | if scheduler == "plms": | 173 | if scheduler == "plms": |
| 172 | scheduler = PNDMScheduler( | 174 | scheduler = PNDMScheduler( |
| @@ -191,7 +193,6 @@ def create_pipeline(model, scheduler, dtype): | |||
| 191 | unet=unet, | 193 | unet=unet, |
| 192 | tokenizer=tokenizer, | 194 | tokenizer=tokenizer, |
| 193 | scheduler=scheduler, | 195 | scheduler=scheduler, |
| 194 | feature_extractor=feature_extractor | ||
| 195 | ) | 196 | ) |
| 196 | # pipeline.enable_attention_slicing() | 197 | # pipeline.enable_attention_slicing() |
| 197 | pipeline.to("cuda") | 198 | pipeline.to("cuda") |
