summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/infer.py b/infer.py
index 1fd11e2..efeb24d 100644
--- a/infer.py
+++ b/infer.py
@@ -181,7 +181,7 @@ def save_args(basepath, args, extra={}):
181 json.dump(info, f, indent=4) 181 json.dump(info, f, indent=4)
182 182
183 183
184def create_pipeline(model, ti_embeddings_dir, dtype): 184def create_pipeline(model, embeddings_dir, dtype):
185 print("Loading Stable Diffusion pipeline...") 185 print("Loading Stable Diffusion pipeline...")
186 186
187 tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) 187 tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -190,7 +190,8 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
190 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 190 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
191 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) 191 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
192 192
193 load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir)) 193 added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
194 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
194 195
195 pipeline = VlpnStableDiffusion( 196 pipeline = VlpnStableDiffusion(
196 text_encoder=text_encoder, 197 text_encoder=text_encoder,