summary refs log tree commit diff stats
path: root/infer.py
diff options
context:
space:
mode:
author   Volpeon <git@volpeon.ink>  2023-01-14 09:25:13 +0100
committer Volpeon <git@volpeon.ink>  2023-01-14 09:25:13 +0100
commite2d3a62bce63fcde940395a1c5618c4eb43385a9 (patch)
tree574f7a794feab13e1cf0ed18522a66d4737b6db3 /infer.py
parentUnified training script structure (diff)
download  textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.gz
          textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.bz2
          textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.zip
Cleanup
Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py  19
1 files changed, 4 insertions, 15 deletions
diff --git a/infer.py b/infer.py
index 2b07b21..36b5a2c 100644
--- a/infer.py
+++ b/infer.py
@@ -214,21 +214,10 @@ def load_embeddings(pipeline, embeddings_dir):
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
-    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
-
-    patch_managed_embeddings(text_encoder)
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    )
+    pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype)
+
+    patch_managed_embeddings(pipeline.text_encoder)
+
     pipeline.enable_xformers_memory_efficient_attention()
     pipeline.enable_vae_slicing()
     pipeline.to("cuda")