diff options
author | Volpeon <git@volpeon.ink> | 2022-12-31 12:58:54 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-31 12:58:54 +0100 |
commit | 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch) | |
tree | 52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /infer.py | |
parent | Misc improvements (diff) | |
download | textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.gz textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.bz2 textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.zip |
Added multi-vector embeddings
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 14 |
1 files changed, 9 insertions, 5 deletions
@@ -8,6 +8,7 @@ from pathlib import Path | |||
8 | import torch | 8 | import torch |
9 | import json | 9 | import json |
10 | from PIL import Image | 10 | from PIL import Image |
11 | from slugify import slugify | ||
11 | from diffusers import ( | 12 | from diffusers import ( |
12 | AutoencoderKL, | 13 | AutoencoderKL, |
13 | UNet2DConditionModel, | 14 | UNet2DConditionModel, |
@@ -20,11 +21,12 @@ from diffusers import ( | |||
20 | KDPM2DiscreteScheduler, | 21 | KDPM2DiscreteScheduler, |
21 | KDPM2AncestralDiscreteScheduler | 22 | KDPM2AncestralDiscreteScheduler |
22 | ) | 23 | ) |
23 | from transformers import CLIPTextModel, CLIPTokenizer | 24 | from transformers import CLIPTextModel |
24 | from slugify import slugify | ||
25 | 25 | ||
26 | from models.clip.embeddings import patch_managed_embeddings | ||
27 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 28 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from common import load_text_embeddings, load_config | 29 | from common import load_config, load_embeddings_from_dir |
28 | 30 | ||
29 | 31 | ||
30 | torch.backends.cuda.matmul.allow_tf32 = True | 32 | torch.backends.cuda.matmul.allow_tf32 = True |
@@ -183,13 +185,15 @@ def save_args(basepath, args, extra={}): | |||
183 | def create_pipeline(model, embeddings_dir, dtype): | 185 | def create_pipeline(model, embeddings_dir, dtype): |
184 | print("Loading Stable Diffusion pipeline...") | 186 | print("Loading Stable Diffusion pipeline...") |
185 | 187 | ||
186 | tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 188 | tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) |
187 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) | 189 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) |
188 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) | 190 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) |
189 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 191 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) |
190 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | 192 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) |
191 | 193 | ||
192 | added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir)) | 194 | embeddings = patch_managed_embeddings(text_encoder) |
195 | added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir)) | ||
196 | |||
193 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") | 197 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") |
194 | 198 | ||
195 | pipeline = VlpnStableDiffusion( | 199 | pipeline = VlpnStableDiffusion( |