summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 21:53:07 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 21:53:07 +0100
commit83808fe00ac891ad2f625388d144c318b2cb5bfe (patch)
treeb7ca19d27f90be6f02b14f4a39c62fc7250041a2 /infer.py
parentTI: Prepare UNet with Accelerate as well (diff)
downloadtextual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.gz
textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.bz2
textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.zip
WIP: Modularization ("free(): invalid pointer" my ass)
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/infer.py b/infer.py
index 36b5a2c..2b07b21 100644
--- a/infer.py
+++ b/infer.py
@@ -214,10 +214,21 @@ def load_embeddings(pipeline, embeddings_dir):
214def create_pipeline(model, dtype): 214def create_pipeline(model, dtype):
215 print("Loading Stable Diffusion pipeline...") 215 print("Loading Stable Diffusion pipeline...")
216 216
217 pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype) 217 tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
218 218 text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
219 patch_managed_embeddings(pipeline.text_encoder) 219 vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
220 220 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
221 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
222
223 patch_managed_embeddings(text_encoder)
224
225 pipeline = VlpnStableDiffusion(
226 text_encoder=text_encoder,
227 vae=vae,
228 unet=unet,
229 tokenizer=tokenizer,
230 scheduler=scheduler,
231 )
221 pipeline.enable_xformers_memory_efficient_attention() 232 pipeline.enable_xformers_memory_efficient_attention()
222 pipeline.enable_vae_slicing() 233 pipeline.enable_vae_slicing()
223 pipeline.to("cuda") 234 pipeline.to("cuda")