diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
| commit | 6d46bf79bd7710cea799fbfe27c12d06d12cd53f (patch) | |
| tree | 6c65817b9351453bfb5366f7010f8d87659c0dd0 /util | |
| parent | Fix cycle loop (diff) | |
| download | textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.gz textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.bz2 textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.zip | |
Update
Diffstat (limited to 'util')
| -rw-r--r-- | util/ti.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/util/ti.py b/util/ti.py new file mode 100644 index 0000000..4cc732e --- /dev/null +++ b/util/ti.py | |||
| @@ -0,0 +1,24 @@ | |||
| 1 | from pathlib import Path | ||
| 2 | |||
| 3 | import torch | ||
| 4 | |||
| 5 | from models.clip.embeddings import ManagedCLIPTextEmbeddings | ||
| 6 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
| 7 | |||
| 8 | |||
| 9 | def load_embeddings( | ||
| 10 | tokenizer: MultiCLIPTokenizer, | ||
| 11 | embeddings: ManagedCLIPTextEmbeddings, | ||
| 12 | tokens: list[str], | ||
| 13 | token_embeddings: torch.FloatTensor, | ||
| 14 | ): | ||
| 15 | num_vectors = [embedding.shape[0] for embedding in token_embeddings] | ||
| 16 | |||
| 17 | token_ids = tokenizer.add_multi_tokens(tokens, num_vectors) | ||
| 18 | |||
| 19 | embeddings.resize(len(tokenizer)) | ||
| 20 | |||
| 21 | for (new_id, embeds) in zip(token_ids, token_embeddings): | ||
| 22 | embeddings.add_embed(new_id, embeds) | ||
| 23 | |||
| 24 | return tokens, token_ids | ||
