author     Volpeon <git@volpeon.ink>  2022-12-31 12:58:54 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-31 12:58:54 +0100
commit     6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch)
tree       52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /infer.py
parent     Misc improvements (diff)
Added multi-vector embeddings
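
Note on the change: "multi-vector" here means one placeholder token can own several rows of the text encoder's embedding matrix instead of a single vector. The mechanics live in models/clip/tokenizer.py (MultiCLIPTokenizer) and models/clip/embeddings.py (patch_managed_embeddings), which the diff below only imports. As a rough illustration of the idea — every name in this sketch is hypothetical, not the repository's code:

# Illustration of the multi-vector idea only; not code from models/clip/.
from transformers import CLIPTokenizer

class MultiVectorTokenizerSketch:
    """Expands one placeholder (e.g. '<ti>') into N sub-tokens
    ('<ti>_0' ... '<ti>_{N-1}') so the concept spans N embedding rows."""

    def __init__(self, tokenizer: CLIPTokenizer):
        self.tokenizer = tokenizer
        self.expansions = {}  # placeholder -> list of sub-token strings

    def add_multi_vector_token(self, token: str, num_vectors: int):
        subs = [f"{token}_{i}" for i in range(num_vectors)]
        self.tokenizer.add_tokens(subs)  # standard HF tokenizer API
        self.expansions[token] = subs
        return self.tokenizer.convert_tokens_to_ids(subs)

    def encode(self, text: str):
        # Rewrite the prompt so each placeholder occupies all of its
        # sub-token positions before normal tokenization runs.
        for token, subs in self.expansions.items():
            text = text.replace(token, " ".join(subs))
        return self.tokenizer(text).input_ids

Expanding at the tokenizer level keeps the text encoder untouched apart from the extra embedding rows, which is why infer.py below only needs to swap the tokenizer class and patch the embedding layer.
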
Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py  14
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/infer.py b/infer.py
index ae0b4da..4bcaff5 100644
--- a/infer.py
+++ b/infer.py
@@ -8,6 +8,7 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
+from slugify import slugify
 from diffusers import (
     AutoencoderKL,
     UNet2DConditionModel,
@@ -20,11 +21,12 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler
 )
-from transformers import CLIPTextModel, CLIPTokenizer
-from slugify import slugify
+from transformers import CLIPTextModel
 
+from models.clip.embeddings import patch_managed_embeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -183,13 +185,15 @@ def save_args(basepath, args, extra={}):
 def create_pipeline(model, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
+    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
     text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir))
+    embeddings = patch_managed_embeddings(text_encoder)
+    added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir))
+
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     pipeline = VlpnStableDiffusion(
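
Net effect on create_pipeline: rather than loading single-vector embeddings into a stock CLIPTokenizer/CLIPTextModel pair via load_text_embeddings, the text encoder's embedding layer is first patched (patch_managed_embeddings), and the patched object plus the MultiCLIPTokenizer are handed to load_embeddings_from_dir, so one embedding file can contribute several vectors per token. Neither helper's body appears in this diff; the sketch below guesses at the loading loop from its call site alone, and the file layout plus both method names (add_multi_tokens, add_embed) are assumptions.

# Hypothetical sketch of load_embeddings_from_dir, inferred from the
# call site above. The real code lives in common.py and may differ.
from pathlib import Path
import torch

def load_embeddings_from_dir_sketch(tokenizer, embeddings, embeddings_dir: Path):
    added_tokens = []
    for path in sorted(embeddings_dir.glob("*.bin")):
        token = path.stem                               # filename names the token
        vectors = torch.load(path, map_location="cpu")  # assumed shape (num_vectors, dim)
        ids = tokenizer.add_multi_tokens(token, len(vectors))  # hypothetical API
        embeddings.add_embed(ids, vectors)                      # hypothetical API
        added_tokens.append(token)
    return added_tokens

If the helpers behave like this, usage stays a one-liner — create_pipeline('<model dir>', 'embeddings', torch.float16) — and any *.bin file dropped into the embeddings directory becomes addressable in prompts by its filename.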