Diffstat (limited to 'common.py')
 common.py | 41 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+), 0 deletions(-)
diff --git a/common.py b/common.py
new file mode 100644
index 0000000..8d6b55d
--- /dev/null
+++ b/common.py
@@ -0,0 +1,41 @@
+from pathlib import Path
+import torch
+
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+def load_text_embedding(embeddings, token_id, file):
+    """Copy one learned embedding from file into row token_id of the embedding matrix."""
+    data = torch.load(file, map_location="cpu")
+
+    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
+
+    emb = next(iter(data.values()))
+    if len(emb.shape) == 1:
+        # Promote a bare vector to shape (1, hidden_size) for the row assignment below.
+        emb = emb.unsqueeze(0)
+
+    embeddings[token_id] = emb
+
+
+def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+    """Register every file in embeddings_dir as a new token; return how many tokens were added."""
+    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+        return 0
+
+    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
+
+    # Each file stem becomes the token string, e.g. "<my-token>.pt" -> "<my-token>".
+    tokens = [file.stem for file in files]
+    added = tokenizer.add_tokens(tokens)
+    token_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+    # Grow the encoder's embedding matrix to cover the newly added tokens.
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
+    for (token_id, file) in zip(token_ids, files):
+        load_text_embedding(token_embeds, token_id, file)
+
+    return added
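
For reference, a minimal usage sketch (not part of the commit): it assumes the openai/clip-vit-large-patch14 checkpoint, which is the text encoder used by Stable Diffusion v1.x and has a 768-dimensional embedding; the embeddings directory and the "<my-token>" name are placeholders. The loader expects each file to be a torch.save()'d dict with exactly one entry, and the file stem becomes the new token string.

    from pathlib import Path

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    from common import load_text_embeddings

    # Placeholder setup: save one fake learned embedding in the expected
    # format -- a single-entry dict keyed by the token string.
    embeddings_dir = Path("embeddings")
    embeddings_dir.mkdir(exist_ok=True)
    torch.save({"<my-token>": torch.randn(768)}, embeddings_dir / "<my-token>.pt")

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    added = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
    print(f"added {added} token(s)")  # prompts can now use "<my-token>"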