diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/ti.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/util/ti.py b/util/ti.py new file mode 100644 index 0000000..4cc732e --- /dev/null +++ b/util/ti.py | |||
@@ -0,0 +1,24 @@ | |||
1 | from pathlib import Path | ||
2 | |||
3 | import torch | ||
4 | |||
5 | from models.clip.embeddings import ManagedCLIPTextEmbeddings | ||
6 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
7 | |||
8 | |||
def load_embeddings(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    tokens: list[str],
    token_embeddings: torch.FloatTensor,
):
    """Register new tokens and copy their embedding vectors into ``embeddings``.

    Args:
        tokenizer: Tokenizer that receives the new tokens through
            ``add_multi_tokens``.
        embeddings: Embedding container; it is resized to ``len(tokenizer)``
            and then filled through ``add_embed``.
        tokens: Token strings to add, one per entry of ``token_embeddings``.
        token_embeddings: Iterable of per-token embeddings. The leading
            dimension (``shape[0]``) of each entry is used as the number of
            vectors for the corresponding token.
            NOTE(review): iterated like a sequence of tensors; a stacked
            tensor also works when every token has the same vector count.

    Returns:
        A ``(tokens, token_ids)`` tuple, where ``token_ids`` is whatever
        ``tokenizer.add_multi_tokens`` returned for the new tokens.

    Raises:
        ValueError: If ``tokens`` and ``token_embeddings`` differ in length.
    """
    num_vectors = [embedding.shape[0] for embedding in token_embeddings]

    # Validate before touching the tokenizer or the embedding table: the
    # zip() below would otherwise silently drop the unmatched entries after
    # the tokenizer has already been modified.
    if len(tokens) != len(num_vectors):
        raise ValueError(
            f"Got {len(tokens)} tokens but {len(num_vectors)} token embeddings"
        )

    token_ids = tokenizer.add_multi_tokens(tokens, num_vectors)

    # The tokenizer has grown, so make room for the new ids first.
    embeddings.resize(len(tokenizer))

    for new_id, embeds in zip(token_ids, token_embeddings):
        embeddings.add_embed(new_id, embeds)

    return tokens, token_ids