author    Volpeon <git@volpeon.ink>    2022-12-13 23:09:25 +0100
committer Volpeon <git@volpeon.ink>    2022-12-13 23:09:25 +0100
commit    03303d3bddba5a27a123babdf90863e27501e6f8 (patch)
tree      8266c50f8e474d92ad4b42773cb8eb7730cd24c1 /common.py
parent    Optimized Textual Inversion training by filtering dataset by existence of add... (diff)
Unified loading of TI embeddings
Diffstat (limited to 'common.py')
-rw-r--r--  common.py  36
1 file changed, 36 insertions(+), 0 deletions(-)
diff --git a/common.py b/common.py
new file mode 100644
index 0000000..8d6b55d
--- /dev/null
+++ b/common.py
@@ -0,0 +1,36 @@
+from pathlib import Path
+import torch
+
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+def load_text_embedding(embeddings, token_id, file):
+    data = torch.load(file, map_location="cpu")
+
+    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
+
+    emb = next(iter(data.values()))
+    if len(emb.shape) == 1:
+        emb = emb.unsqueeze(0)
+
+    embeddings[token_id] = emb
+
+
+def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+        return 0
+
+    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
+
+    tokens = [file.stem for file in files]
+    added = tokenizer.add_tokens(tokens)
+    token_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
+    for (token_id, file) in zip(token_ids, files):
+        load_text_embedding(token_embeds, token_id, file)
+
+    return added
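
For context, a minimal usage sketch of the unified loader (not part of the commit; the checkpoint name and the "embeddings" directory below are placeholders, not values taken from this repository):

# Hypothetical example: load a CLIP tokenizer and text encoder, then register every
# embedding file found in an embeddings directory. Each file's stem becomes a new
# token, and its tensor is written into the text encoder's embedding matrix.
from pathlib import Path

from transformers import CLIPTextModel, CLIPTokenizer

from common import load_text_embeddings

pretrained = "runwayml/stable-diffusion-v1-5"  # assumed checkpoint, for illustration only
tokenizer = CLIPTokenizer.from_pretrained(pretrained, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

added = load_text_embeddings(tokenizer, text_encoder, Path("embeddings"))
print(f"added {added} tokens from TI embeddings")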