summaryrefslogtreecommitdiffstats
path: root/util/ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'util/ti.py')
-rw-r--r--util/ti.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/util/ti.py b/util/ti.py
new file mode 100644
index 0000000..4cc732e
--- /dev/null
+++ b/util/ti.py
@@ -0,0 +1,24 @@
1from pathlib import Path
2
3import torch
4
5from models.clip.embeddings import ManagedCLIPTextEmbeddings
6from models.clip.tokenizer import MultiCLIPTokenizer
7
8
9def load_embeddings(
10 tokenizer: MultiCLIPTokenizer,
11 embeddings: ManagedCLIPTextEmbeddings,
12 tokens: list[str],
13 token_embeddings: torch.FloatTensor,
14):
15 num_vectors = [embedding.shape[0] for embedding in token_embeddings]
16
17 token_ids = tokenizer.add_multi_tokens(tokens, num_vectors)
18
19 embeddings.resize(len(tokenizer))
20
21 for (new_id, embeds) in zip(token_ids, token_embeddings):
22 embeddings.add_embed(new_id, embeds)
23
24 return tokens, token_ids