From 6d46bf79bd7710cea799fbfe27c12d06d12cd53f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Apr 2023 07:47:59 +0200 Subject: Update --- util/ti.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 util/ti.py (limited to 'util') diff --git a/util/ti.py b/util/ti.py new file mode 100644 index 0000000..4cc732e --- /dev/null +++ b/util/ti.py @@ -0,0 +1,24 @@ +from pathlib import Path + +import torch + +from models.clip.embeddings import ManagedCLIPTextEmbeddings +from models.clip.tokenizer import MultiCLIPTokenizer + + +def load_embeddings( + tokenizer: MultiCLIPTokenizer, + embeddings: ManagedCLIPTextEmbeddings, + tokens: list[str], + token_embeddings: torch.FloatTensor, +): + num_vectors = [embedding.shape[0] for embedding in token_embeddings] + + token_ids = tokenizer.add_multi_tokens(tokens, num_vectors) + + embeddings.resize(len(tokenizer)) + + for (new_id, embeds) in zip(token_ids, token_embeddings): + embeddings.add_embed(new_id, embeds) + + return tokens, token_ids -- cgit v1.2.3-70-g09d2