author     Volpeon <git@volpeon.ink>   2023-04-27 07:47:59 +0200
committer  Volpeon <git@volpeon.ink>   2023-04-27 07:47:59 +0200
commit     6d46bf79bd7710cea799fbfe27c12d06d12cd53f (patch)
tree       6c65817b9351453bfb5366f7010f8d87659c0dd0 /util
parent     Fix cycle loop (diff)
Update
Diffstat (limited to 'util')
-rw-r--r--  util/ti.py  24
1 file changed, 24 insertions(+), 0 deletions(-)
diff --git a/util/ti.py b/util/ti.py
new file mode 100644
index 0000000..4cc732e
--- /dev/null
+++ b/util/ti.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+
+import torch
+
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+
+def load_embeddings(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    tokens: list[str],
+    token_embeddings: torch.FloatTensor,
+):
+    num_vectors = [embedding.shape[0] for embedding in token_embeddings]
+
+    token_ids = tokenizer.add_multi_tokens(tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, embeds) in zip(token_ids, token_embeddings):
+        embeddings.add_embed(new_id, embeds)
+
+    return tokens, token_ids
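
The new load_embeddings helper wires textual-inversion tokens into the pipeline in three steps: it registers each token with its vector count via add_multi_tokens, resizes the embedding matrix so the newly issued ids are valid indices, and then copies each token's vectors in with add_embed. Below is a minimal usage sketch, not part of the commit: it assumes util/ti.py is importable as util.ti and that each token's embedding is stored as a separate .bin file of shape (num_vectors, embedding_dim); the load_from_dir helper and the file layout are illustrative only.

# Illustrative sketch only -- not part of this commit. Loads one embedding
# tensor per file and registers them via the load_embeddings() helper above.
from pathlib import Path

import torch

from models.clip.embeddings import ManagedCLIPTextEmbeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from util.ti import load_embeddings  # assumes util/ is on the import path


def load_from_dir(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    embeddings_dir: Path,
):
    paths = sorted(embeddings_dir.glob("*.bin"))  # assumed layout: one token per file
    tokens = [path.stem for path in paths]  # token name taken from the file name
    # Each loaded tensor is expected to have shape (num_vectors, embedding_dim),
    # matching the embedding.shape[0] lookup inside load_embeddings().
    token_embeddings = [torch.load(path, map_location="cpu") for path in paths]
    return load_embeddings(tokenizer, embeddings, tokens, token_embeddings)

Calling embeddings.resize(len(tokenizer)) before add_embed is what keeps the ids returned by add_multi_tokens in range of the embedding matrix.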