From 83808fe00ac891ad2f625388d144c318b2cb5bfe Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 21:53:07 +0100 Subject: WIP: Modularization ("free(): invalid pointer" my ass) --- infer.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 36b5a2c..2b07b21 100644 --- a/infer.py +++ b/infer.py @@ -214,10 +214,21 @@ def load_embeddings(pipeline, embeddings_dir): def create_pipeline(model, dtype): print("Loading Stable Diffusion pipeline...") - pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype) - - patch_managed_embeddings(pipeline.text_encoder) - + tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) + text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) + vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) + unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) + scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) + + patch_managed_embeddings(text_encoder) + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + ) pipeline.enable_xformers_memory_efficient_attention() pipeline.enable_vae_slicing() pipeline.to("cuda") -- cgit v1.2.3-54-g00ecf