|           |                                                     |                           |
|-----------|-----------------------------------------------------|---------------------------|
| author    | Volpeon <git@volpeon.ink>                           | 2022-12-31 12:58:54 +0100 |
| committer | Volpeon <git@volpeon.ink>                           | 2022-12-31 12:58:54 +0100 |
| commit    | 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch)    |                           |
| tree      | 52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /common.py |                           |
| parent    | Misc improvements (diff)                            |                           |
| download  | textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.gz textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.bz2 textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.zip | |
Added multi-vector embeddings
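This commit replaces the old `torch.load`-based single-vector loader with one that reads multi-vector embeddings from safetensors files: one file per token, the file stem as the token name, and a single tensor stored under the key `embed` whose first dimension is the number of vectors. As a minimal sketch of producing a file the new loader can read (the token name `<test>`, the directory, and the 768-dimensional hidden size are illustrative choices, not taken from this commit; `save_file` is the standard safetensors PyTorch API):

```python
from pathlib import Path

import torch
from safetensors.torch import save_file

embeddings_dir = Path("embeddings")  # illustrative location
embeddings_dir.mkdir(exist_ok=True)

# Three vectors for one placeholder token; 768 is the CLIP text encoder
# hidden size, used here purely for illustration.
embed = torch.randn(3, 768)

# load_embeddings_from_dir (see the diff below) reads the tensor stored
# under the key "embed" and takes the token name from the file stem, so
# this file defines a 3-vector token named "<test>".
save_file({"embed": embed}, str(embeddings_dir / "<test>.safetensors"))
```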
Diffstat (limited to 'common.py')

| mode       | file      | lines changed |
|------------|-----------|---------------|
| -rw-r--r-- | common.py | 38            |

1 file changed, 13 insertions, 25 deletions
```diff
--- a/common.py
+++ b/common.py
@@ -1,9 +1,10 @@
 from pathlib import Path
 import json
 
-import torch
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 
-from transformers import CLIPTextModel, CLIPTokenizer
+from safetensors import safe_open
 
 
 def load_config(filename):
@@ -18,33 +19,20 @@ def load_config(filename):
     return args
 
 
-def load_text_embedding(embeddings, token_id, file):
-    data = torch.load(file, map_location="cpu")
-
-    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
-
-    emb = next(iter(data.values()))
-    if len(emb.shape) == 1:
-        emb = emb.unsqueeze(0)
-
-    embeddings[token_id] = emb
-
-
-def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
         return []
 
-    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
-
-    tokens = [file.stem for file in files]
-    added = tokenizer.add_tokens(tokens)
-    token_ids = tokenizer.convert_tokens_to_ids(tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    tokens = [filename.stem for filename in filenames]
 
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for filename in embeddings_dir.iterdir():
+        if filename.is_file():
+            with safe_open(filename, framework="pt", device="cpu") as file:
+                embed = file.get_tensor("embed")
 
-    for (token_id, file) in zip(token_ids, files):
-        load_text_embedding(token_embeds, token_id, file)
+                added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
+                embeddings.add_embed(added.placeholder_id)
+                embeddings.add_embed(added.multi_ids, embed)
 
     return tokens
```
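Outside the training code, the same safetensors calls the new function uses can be replayed to sanity-check such files; this sketch assumes only the file layout described above, not any of the repository's own classes:

```python
from pathlib import Path

from safetensors import safe_open

# Mirror the loop in load_embeddings_from_dir: each file contributes one
# token (named after the file stem) with embed.shape[0] vectors.
for filename in Path("embeddings").iterdir():
    if filename.is_file():
        with safe_open(str(filename), framework="pt", device="cpu") as file:
            embed = file.get_tensor("embed")
        print(f"token {filename.stem!r}: {embed.shape[0]} vectors of dim {embed.shape[1]}")
```

Note that `add_multi_tokens` and `add_embed` in the diff are methods of this repository's own `MultiCLIPTokenizer` and `ManagedCLIPTextEmbeddings` wrappers (imported from `models.clip`), not part of the upstream transformers API.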
