diff options
author | Volpeon <git@volpeon.ink> | 2022-12-13 23:09:25 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-13 23:09:25 +0100 |
commit | 03303d3bddba5a27a123babdf90863e27501e6f8 (patch) | |
tree | 8266c50f8e474d92ad4b42773cb8eb7730cd24c1 /common.py | |
parent | Optimized Textual Inversion training by filtering dataset by existence of add... (diff) | |
download | textual-inversion-diff-03303d3bddba5a27a123babdf90863e27501e6f8.tar.gz textual-inversion-diff-03303d3bddba5a27a123babdf90863e27501e6f8.tar.bz2 textual-inversion-diff-03303d3bddba5a27a123babdf90863e27501e6f8.zip |
Unified loading of TI embeddings
Diffstat (limited to 'common.py')
-rw-r--r-- | common.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/common.py b/common.py new file mode 100644 index 0000000..8d6b55d --- /dev/null +++ b/common.py | |||
@@ -0,0 +1,36 @@ | |||
1 | from pathlib import Path | ||
2 | import torch | ||
3 | |||
4 | from transformers import CLIPTextModel, CLIPTokenizer | ||
5 | |||
6 | |||
def load_text_embedding(embeddings, token_id, file):
    """Load one textual-inversion embedding from *file* into *embeddings*.

    Args:
        embeddings: Mutable, index-assignable store (e.g. the text encoder's
            input-embedding weight tensor); written at position ``token_id``.
        token_id: Index at which the loaded vector is stored.
        file: Path to a ``torch.save``-serialized dict that must contain
            exactly one entry (the embedding tensor).

    Raises:
        ValueError: If the file does not contain exactly one term.
    """
    # Load on CPU so this works regardless of the training device.
    data = torch.load(file, map_location="cpu")

    # `assert` is stripped under `python -O`, so validate explicitly.
    if len(data) != 1:
        raise ValueError('embedding data must contain exactly one term')

    emb = next(iter(data.values()))
    # Normalize single-vector embeddings to shape (1, dim) so single- and
    # multi-vector embeddings are stored uniformly.
    if len(emb.shape) == 1:
        emb = emb.unsqueeze(0)

    embeddings[token_id] = emb
17 | |||
18 | |||
def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
    """Register every embedding file in *embeddings_dir* with the tokenizer.

    Each file's stem becomes a token string; the tensor stored in the file is
    copied into the text encoder's input-embedding matrix at that token's id.

    Args:
        tokenizer: Tokenizer that receives the new tokens.
        text_encoder: Text encoder whose embedding matrix is resized and filled.
        embeddings_dir: Directory whose files are treated as saved embeddings.

    Returns:
        The number of tokens newly added to the tokenizer; 0 when the
        directory does not exist or is not a directory.
    """
    if not (embeddings_dir.exists() and embeddings_dir.is_dir()):
        return 0

    embedding_files = [entry for entry in embeddings_dir.iterdir() if entry.is_file()]
    token_names = [entry.stem for entry in embedding_files]

    added = tokenizer.add_tokens(token_names)
    ids = tokenizer.convert_tokens_to_ids(token_names)

    # Grow the embedding matrix so every freshly added token id has a row.
    text_encoder.resize_token_embeddings(len(tokenizer))
    weight = text_encoder.get_input_embeddings().weight.data

    for token_id, embedding_file in zip(ids, embedding_files):
        load_text_embedding(weight, token_id, embedding_file)

    return added