diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
commit | e2d3a62bce63fcde940395a1c5618c4eb43385a9 (patch) | |
tree | 574f7a794feab13e1cf0ed18522a66d4737b6db3 /infer.py | |
parent | Unified training script structure (diff) | |
download | textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.gz textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.bz2 textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.zip |
Cleanup
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 19 |
1 files changed, 4 insertions, 15 deletions
@@ -214,21 +214,10 @@ def load_embeddings(pipeline, embeddings_dir): | |||
214 | def create_pipeline(model, dtype): | 214 | def create_pipeline(model, dtype): |
215 | print("Loading Stable Diffusion pipeline...") | 215 | print("Loading Stable Diffusion pipeline...") |
216 | 216 | ||
217 | tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 217 | pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype) |
218 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) | 218 | |
219 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) | 219 | patch_managed_embeddings(pipeline.text_encoder) |
220 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 220 | |
221 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | ||
222 | |||
223 | patch_managed_embeddings(text_encoder) | ||
224 | |||
225 | pipeline = VlpnStableDiffusion( | ||
226 | text_encoder=text_encoder, | ||
227 | vae=vae, | ||
228 | unet=unet, | ||
229 | tokenizer=tokenizer, | ||
230 | scheduler=scheduler, | ||
231 | ) | ||
232 | pipeline.enable_xformers_memory_efficient_attention() | 221 | pipeline.enable_xformers_memory_efficient_attention() |
233 | pipeline.enable_vae_slicing() | 222 | pipeline.enable_vae_slicing() |
234 | pipeline.to("cuda") | 223 | pipeline.to("cuda") |