diff options
author | Volpeon <git@volpeon.ink> | 2022-12-31 12:58:54 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-31 12:58:54 +0100 |
commit | 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch) | |
tree | 52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /models/clip/embeddings.py | |
parent | Misc improvements (diff) | |
download | textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.gz textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.bz2 textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.zip |
Added multi-vector embeddings
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r-- | models/clip/embeddings.py | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py new file mode 100644 index 0000000..7d63ffb --- /dev/null +++ b/models/clip/embeddings.py | |||
@@ -0,0 +1,109 @@ | |||
1 | from typing import Union, Optional | ||
2 | from pathlib import Path | ||
3 | |||
4 | import torch | ||
5 | import torch.nn as nn | ||
6 | |||
7 | from safetensors import safe_open | ||
8 | from safetensors.torch import save_file | ||
9 | |||
10 | from transformers import CLIPTextModel | ||
11 | from transformers.models.clip import CLIPTextConfig | ||
12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | ||
13 | |||
14 | |||
15 | def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding: | ||
16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.size() | ||
17 | |||
18 | new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim) | ||
19 | new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype) | ||
20 | new_embedding.weight.data.zero_() | ||
21 | new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data | ||
22 | |||
23 | return new_embedding | ||
24 | |||
25 | |||
26 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | ||
27 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): | ||
28 | super().__init__(config) | ||
29 | |||
30 | self.token_embedding = embeddings.token_embedding | ||
31 | self.position_embedding = embeddings.position_embedding | ||
32 | |||
33 | self.temp_token_embedding = nn.Embedding( | ||
34 | self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) | ||
35 | self.temp_token_embedding.weight.data.zero_() | ||
36 | self.temp_token_ids = torch.tensor([]) | ||
37 | |||
38 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | ||
39 | if isinstance(token_ids, int): | ||
40 | token_ids = [token_ids] | ||
41 | |||
42 | if initializer is not None: | ||
43 | if isinstance(initializer, int): | ||
44 | initializer = [initializer] | ||
45 | |||
46 | if isinstance(initializer, list): | ||
47 | initializer = (initializer * len(token_ids))[:len(token_ids)] | ||
48 | |||
49 | with torch.no_grad(): | ||
50 | initializer = self.get_embed(initializer) | ||
51 | |||
52 | self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids)) | ||
53 | self.token_embedding = expand_embedding(self.token_embedding, len(token_ids)) | ||
54 | |||
55 | token_ids = torch.tensor(token_ids) | ||
56 | |||
57 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | ||
58 | |||
59 | if initializer is not None: | ||
60 | self.temp_token_embedding.weight.data[token_ids] = initializer | ||
61 | else: | ||
62 | self.temp_token_embedding.weight.data[token_ids].zero_() | ||
63 | |||
64 | def load_embed(self, input_ids: list[int], filename: Path): | ||
65 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
66 | self.add_embed(input_ids, file.get_tensor("embed")) | ||
67 | |||
68 | def save_embed(self, input_ids: list[int], filename: Path): | ||
69 | save_file({"embed": self.get_embed(input_ids)}, filename) | ||
70 | |||
71 | def make_permanent(self): | ||
72 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | ||
73 | self.temp_token_ids = torch.tensor([]) | ||
74 | |||
75 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | ||
76 | if isinstance(input_ids, list): | ||
77 | input_ids = torch.tensor(input_ids) | ||
78 | |||
79 | mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)) | ||
80 | |||
81 | embeds = self.token_embedding(input_ids) | ||
82 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | ||
83 | |||
84 | return embeds | ||
85 | |||
86 | def forward( | ||
87 | self, | ||
88 | input_ids: Optional[torch.LongTensor] = None, | ||
89 | position_ids: Optional[torch.LongTensor] = None, | ||
90 | inputs_embeds: Optional[torch.FloatTensor] = None, | ||
91 | ) -> torch.Tensor: | ||
92 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] | ||
93 | |||
94 | if position_ids is None: | ||
95 | position_ids = self.position_ids[:, :seq_length] | ||
96 | |||
97 | if inputs_embeds is None: | ||
98 | inputs_embeds = self.get_embed(input_ids) | ||
99 | |||
100 | position_embeddings = self.position_embedding(position_ids) | ||
101 | embeddings = inputs_embeds + position_embeddings | ||
102 | |||
103 | return embeddings | ||
104 | |||
105 | |||
106 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: | ||
107 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | ||
108 | text_encoder.text_model.embeddings = text_embeddings | ||
109 | return text_embeddings | ||