diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 14 |
1 files changed, 14 insertions, 0 deletions
@@ -171,6 +171,18 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir): | |||
171 | embeddings_dir = Path(embeddings_dir) | 171 | embeddings_dir = Path(embeddings_dir) |
172 | embeddings_dir.mkdir(parents=True, exist_ok=True) | 172 | embeddings_dir.mkdir(parents=True, exist_ok=True) |
173 | 173 | ||
174 | for file in embeddings_dir.iterdir(): | ||
175 | placeholder_token = file.stem | ||
176 | |||
177 | num_added_tokens = tokenizer.add_tokens(placeholder_token) | ||
178 | if num_added_tokens == 0: | ||
179 | raise ValueError( | ||
180 | f"The tokenizer already contains the token {placeholder_token}. Please pass a different" | ||
181 | " `placeholder_token` that is not already in the tokenizer." | ||
182 | ) | ||
183 | |||
184 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
185 | |||
174 | token_embeds = text_encoder.get_input_embeddings().weight.data | 186 | token_embeds = text_encoder.get_input_embeddings().weight.data |
175 | 187 | ||
176 | for file in embeddings_dir.iterdir(): | 188 | for file in embeddings_dir.iterdir(): |
@@ -187,6 +199,8 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir): | |||
187 | 199 | ||
188 | token_embeds[placeholder_token_id] = emb | 200 | token_embeds[placeholder_token_id] = emb |
189 | 201 | ||
202 | print(f"Loaded embedding: {placeholder_token}") | ||
203 | |||
190 | 204 | ||
191 | def create_pipeline(model, scheduler, embeddings_dir, dtype): | 205 | def create_pipeline(model, scheduler, embeddings_dir, dtype): |
192 | print("Loading Stable Diffusion pipeline...") | 206 | print("Loading Stable Diffusion pipeline...") |