| author | Volpeon <git@volpeon.ink> | 2022-10-06 17:15:22 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-06 17:15:22 +0200 |
| commit | 49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 (patch) | |
| tree | 8bd8fe59b2a5b60c2f6e7e1b48b53be7fbf1e155 /infer.py | |
| parent | Inference: Add support for embeddings (diff) | |
Update: register all placeholder tokens and resize the text encoder's embedding matrix before loading embedding vectors; log each embedding as it is loaded.
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 14 |
1 file changed, 14 insertions, 0 deletions
```diff
@@ -171,6 +171,18 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
     embeddings_dir = Path(embeddings_dir)
     embeddings_dir.mkdir(parents=True, exist_ok=True)
 
+    for file in embeddings_dir.iterdir():
+        placeholder_token = file.stem
+
+        num_added_tokens = tokenizer.add_tokens(placeholder_token)
+        if num_added_tokens == 0:
+            raise ValueError(
+                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+                " `placeholder_token` that is not already in the tokenizer."
+            )
+
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
     token_embeds = text_encoder.get_input_embeddings().weight.data
 
     for file in embeddings_dir.iterdir():
@@ -187,6 +199,8 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
 
         token_embeds[placeholder_token_id] = emb
 
+        print(f"Loaded embedding: {placeholder_token}")
+
 
 def create_pipeline(model, scheduler, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
```
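For readability, below is a minimal sketch of what `load_embeddings` plausibly looks like after this commit. Only the registration loop, the `resize_token_embeddings` call, and the log line come from the hunks above; the per-file loading in the middle (the `torch.load` call, the layout of the saved file, and the `convert_tokens_to_ids` lookup that produces `placeholder_token_id`) is an assumption about code this diff does not show.

```python
from pathlib import Path

import torch


def load_embeddings(tokenizer, text_encoder, embeddings_dir):
    embeddings_dir = Path(embeddings_dir)
    embeddings_dir.mkdir(parents=True, exist_ok=True)

    # Added in this commit: register every placeholder token up front, so the
    # embedding matrix is resized exactly once before any vectors are written.
    for file in embeddings_dir.iterdir():
        placeholder_token = file.stem

        num_added_tokens = tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data

    for file in embeddings_dir.iterdir():
        placeholder_token = file.stem

        # Assumption: each file is a torch-saved dict keyed by the token name;
        # the actual on-disk format is not visible in this diff.
        emb = torch.load(file, map_location="cpu")[placeholder_token]

        # Assumption: the id lookup uses the tokenizer's standard conversion.
        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
        token_embeds[placeholder_token_id] = emb

        # Added in this commit: log each embedding as it is loaded.
        print(f"Loaded embedding: {placeholder_token}")
```

The ordering of the added code matters: `resize_token_embeddings` must run before the second loop, because the rows for the newly added token ids do not exist in `token_embeds` until the matrix has been resized. Registering every token first also means the matrix is reallocated once, rather than once per embedding file.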
