author    Volpeon <git@volpeon.ink>    2022-12-31 12:58:54 +0100
committer Volpeon <git@volpeon.ink>    2022-12-31 12:58:54 +0100
commit    6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch)
tree      52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /common.py
parent    Misc improvements (diff)
Added multi-vector embeddings
Diffstat (limited to 'common.py')
-rw-r--r--  common.py  38  +++++++++++++-------------------------
1 file changed, 13 insertions(+), 25 deletions(-)
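
Note: "multi-vector embeddings" means one learned concept occupies several embedding vectors, so the tokenizer has to expand a single placeholder token into a run of sub-tokens. MultiCLIPTokenizer's implementation is not part of this diff; the following is only a minimal sketch of that expansion idea, with all names hypothetical:

    # Hypothetical sketch of multi-token expansion; the real logic lives in
    # models/clip/tokenizer.py (MultiCLIPTokenizer.add_multi_tokens) and is
    # not shown in this commit.
    def expand_placeholder(token: str, num_vectors: int) -> list[str]:
        # One placeholder such as "<concept>" becomes
        # ["<concept>_0", "<concept>_1", ...], so a single concept can
        # span num_vectors positions in the prompt's token sequence.
        return [f"{token}_{i}" for i in range(num_vectors)]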
diff --git a/common.py b/common.py
index f369475..e8d3ac1 100644
--- a/common.py
+++ b/common.py
@@ -1,9 +1,10 @@
 from pathlib import Path
 import json
 
-import torch
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 
-from transformers import CLIPTextModel, CLIPTokenizer
+from safetensors import safe_open
 
 
 def load_config(filename):
@@ -18,33 +19,20 @@ def load_config(filename):
     return args
 
 
-def load_text_embedding(embeddings, token_id, file):
-    data = torch.load(file, map_location="cpu")
-
-    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
-
-    emb = next(iter(data.values()))
-    if len(emb.shape) == 1:
-        emb = emb.unsqueeze(0)
-
-    embeddings[token_id] = emb
-
-
-def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
         return []
 
-    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
-
-    tokens = [file.stem for file in files]
-    added = tokenizer.add_tokens(tokens)
-    token_ids = tokenizer.convert_tokens_to_ids(tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    tokens = [filename.stem for filename in filenames]
 
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for filename in embeddings_dir.iterdir():
+        if filename.is_file():
+            with safe_open(filename, framework="pt", device="cpu") as file:
+                embed = file.get_tensor("embed")
 
-    for (token_id, file) in zip(token_ids, files):
-        load_text_embedding(token_embeds, token_id, file)
+                added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
+                embeddings.add_embed(added.placeholder_id)
+                embeddings.add_embed(added.multi_ids, embed)
 
     return tokens
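
The new loader drops the old torch.load format (a pickled dict holding a single term) in favor of safetensors files that each contain one tensor named "embed" of shape (num_vectors, embedding_dim); the token name comes from the file stem. A minimal sketch of producing such a file and calling the loader, assuming an example 768-dim text encoder and tokenizer/embeddings objects built elsewhere in the repo:

    # Sketch only: save_file and safe_open are real safetensors APIs, but the
    # directory name, tensor shape, and surrounding objects are assumptions.
    import torch
    from pathlib import Path
    from safetensors.torch import save_file

    # Four vectors for one token; 768 is just an example embedding width.
    save_file({"embed": torch.randn(4, 768)}, "embeddings/my_token.safetensors")

    # tokens = load_embeddings_from_dir(tokenizer, embeddings, Path("embeddings"))
    # -> returns ["my_token"]; embed.shape[0] == 4, so "my_token" is expanded
    #    to four token positions by the tokenizer.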