From e2d3a62bce63fcde940395a1c5618c4eb43385a9 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 09:25:13 +0100
Subject: Cleanup

---
 infer.py | 19 ++++---------------
 1 file changed, 4 insertions(+), 15 deletions(-)

diff --git a/infer.py b/infer.py
index 2b07b21..36b5a2c 100644
--- a/infer.py
+++ b/infer.py
@@ -214,21 +214,10 @@ def load_embeddings(pipeline, embeddings_dir):
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
-    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
-
-    patch_managed_embeddings(text_encoder)
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    )
+    pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype)
+
+    patch_managed_embeddings(pipeline.text_encoder)
+
     pipeline.enable_xformers_memory_efficient_attention()
     pipeline.enable_vae_slicing()
     pipeline.to("cuda")
--
cgit v1.2.3-54-g00ecf
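
For context, below is a minimal sketch of create_pipeline() as it reads after this change. Since VlpnStableDiffusion subclasses diffusers' DiffusionPipeline, a single from_pretrained() call loads the tokenizer, text encoder, VAE, UNet, and scheduler from the model's subfolders, replacing the per-component loading removed above. The import paths and the trailing return are assumptions (the hunk ends before the function does), not shown in the patch itself.

import torch

# Assumed module paths within this repository; adjust to the actual layout.
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.embeddings import patch_managed_embeddings


def create_pipeline(model, dtype):
    print("Loading Stable Diffusion pipeline...")

    # DiffusionPipeline.from_pretrained() assembles all sub-models in one call.
    pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype)

    # Patch the managed (multi-token) embeddings onto the loaded text encoder.
    patch_managed_embeddings(pipeline.text_encoder)

    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.enable_vae_slicing()
    pipeline.to("cuda")

    return pipeline  # assumed; the remainder of the function is outside the hunk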