From 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 31 Dec 2022 12:58:54 +0100 Subject: Added multi-vector embeddings --- infer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index ae0b4da..4bcaff5 100644 --- a/infer.py +++ b/infer.py @@ -8,6 +8,7 @@ from pathlib import Path import torch import json from PIL import Image +from slugify import slugify from diffusers import ( AutoencoderKL, UNet2DConditionModel, @@ -20,11 +21,12 @@ from diffusers import ( KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler ) -from transformers import CLIPTextModel, CLIPTokenizer -from slugify import slugify +from transformers import CLIPTextModel +from models.clip.embeddings import patch_managed_embeddings +from models.clip.tokenizer import MultiCLIPTokenizer from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from common import load_text_embeddings, load_config +from common import load_config, load_embeddings_from_dir torch.backends.cuda.matmul.allow_tf32 = True @@ -183,13 +185,15 @@ def save_args(basepath, args, extra={}): def create_pipeline(model, embeddings_dir, dtype): print("Loading Stable Diffusion pipeline...") - tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) + tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) - added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir)) + embeddings = patch_managed_embeddings(text_encoder) + added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir)) + print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") pipeline = VlpnStableDiffusion( -- cgit v1.2.3-54-g00ecf