summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-06 17:15:22 +0200
committerVolpeon <git@volpeon.ink>2022-10-06 17:15:22 +0200
commit49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 (patch)
tree8bd8fe59b2a5b60c2f6e7e1b48b53be7fbf1e155 /infer.py
parentInference: Add support for embeddings (diff)
downloadtextual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.gz
textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.bz2
textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.zip
Update
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/infer.py b/infer.py
index 3487e5a..34e570a 100644
--- a/infer.py
+++ b/infer.py
@@ -171,6 +171,18 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
171 embeddings_dir = Path(embeddings_dir) 171 embeddings_dir = Path(embeddings_dir)
172 embeddings_dir.mkdir(parents=True, exist_ok=True) 172 embeddings_dir.mkdir(parents=True, exist_ok=True)
173 173
174 for file in embeddings_dir.iterdir():
175 placeholder_token = file.stem
176
177 num_added_tokens = tokenizer.add_tokens(placeholder_token)
178 if num_added_tokens == 0:
179 raise ValueError(
180 f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
181 " `placeholder_token` that is not already in the tokenizer."
182 )
183
184 text_encoder.resize_token_embeddings(len(tokenizer))
185
174 token_embeds = text_encoder.get_input_embeddings().weight.data 186 token_embeds = text_encoder.get_input_embeddings().weight.data
175 187
176 for file in embeddings_dir.iterdir(): 188 for file in embeddings_dir.iterdir():
@@ -187,6 +199,8 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
187 199
188 token_embeds[placeholder_token_id] = emb 200 token_embeds[placeholder_token_id] = emb
189 201
202 print(f"Loaded embedding: {placeholder_token}")
203
190 204
191def create_pipeline(model, scheduler, embeddings_dir, dtype): 205def create_pipeline(model, scheduler, embeddings_dir, dtype):
192 print("Loading Stable Diffusion pipeline...") 206 print("Loading Stable Diffusion pipeline...")