summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/infer.py b/infer.py
index b29b136..507d0cf 100644
--- a/infer.py
+++ b/infer.py
@@ -28,7 +28,7 @@ from transformers import CLIPTextModel
28from models.clip.embeddings import patch_managed_embeddings 28from models.clip.embeddings import patch_managed_embeddings
29from models.clip.tokenizer import MultiCLIPTokenizer 29from models.clip.tokenizer import MultiCLIPTokenizer
30from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 30from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
31from common import load_config, load_embeddings_from_dir 31from util import load_config, load_embeddings_from_dir
32 32
33 33
34torch.backends.cuda.matmul.allow_tf32 = True 34torch.backends.cuda.matmul.allow_tf32 = True
@@ -192,12 +192,12 @@ def save_args(basepath, args, extra={}):
192 192
193 193
194def load_embeddings(pipeline, embeddings_dir): 194def load_embeddings(pipeline, embeddings_dir):
195 added_tokens = load_embeddings_from_dir( 195 added_tokens, added_ids = load_embeddings_from_dir(
196 pipeline.tokenizer, 196 pipeline.tokenizer,
197 pipeline.text_encoder.text_model.embeddings, 197 pipeline.text_encoder.text_model.embeddings,
198 Path(embeddings_dir) 198 Path(embeddings_dir)
199 ) 199 )
200 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") 200 print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
201 201
202 202
203def create_pipeline(model, dtype): 203def create_pipeline(model, dtype):