summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py19
1 files changed, 4 insertions, 15 deletions
diff --git a/infer.py b/infer.py
index 2b07b21..36b5a2c 100644
--- a/infer.py
+++ b/infer.py
@@ -214,21 +214,10 @@ def load_embeddings(pipeline, embeddings_dir):
214def create_pipeline(model, dtype): 214def create_pipeline(model, dtype):
215 print("Loading Stable Diffusion pipeline...") 215 print("Loading Stable Diffusion pipeline...")
216 216
217 tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) 217 pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype)
218 text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) 218
219 vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) 219 patch_managed_embeddings(pipeline.text_encoder)
220 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 220
221 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
222
223 patch_managed_embeddings(text_encoder)
224
225 pipeline = VlpnStableDiffusion(
226 text_encoder=text_encoder,
227 vae=vae,
228 unet=unet,
229 tokenizer=tokenizer,
230 scheduler=scheduler,
231 )
232 pipeline.enable_xformers_memory_efficient_attention() 221 pipeline.enable_xformers_memory_efficient_attention()
233 pipeline.enable_vae_slicing() 222 pipeline.enable_vae_slicing()
234 pipeline.to("cuda") 223 pipeline.to("cuda")