diff options
Diffstat (limited to 'util')
| -rw-r--r-- | util/ti.py | 24 | 
1 files changed, 24 insertions, 0 deletions
diff --git a/util/ti.py b/util/ti.py new file mode 100644 index 0000000..4cc732e --- /dev/null +++ b/util/ti.py  | |||
| @@ -0,0 +1,24 @@ | |||
| 1 | from pathlib import Path | ||
| 2 | |||
| 3 | import torch | ||
| 4 | |||
| 5 | from models.clip.embeddings import ManagedCLIPTextEmbeddings | ||
| 6 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
| 7 | |||
| 8 | |||
| 9 | def load_embeddings( | ||
| 10 | tokenizer: MultiCLIPTokenizer, | ||
| 11 | embeddings: ManagedCLIPTextEmbeddings, | ||
| 12 | tokens: list[str], | ||
| 13 | token_embeddings: torch.FloatTensor, | ||
| 14 | ): | ||
| 15 | num_vectors = [embedding.shape[0] for embedding in token_embeddings] | ||
| 16 | |||
| 17 | token_ids = tokenizer.add_multi_tokens(tokens, num_vectors) | ||
| 18 | |||
| 19 | embeddings.resize(len(tokenizer)) | ||
| 20 | |||
| 21 | for (new_id, embeds) in zip(token_ids, token_embeddings): | ||
| 22 | embeddings.add_embed(new_id, embeds) | ||
| 23 | |||
| 24 | return tokens, token_ids | ||
