diff options
author | Volpeon <git@volpeon.ink> | 2022-12-31 23:09:41 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-31 23:09:41 +0100 |
commit | 56edf85c8b80d49c998bcf26392cce50d552137a (patch) | |
tree | c4fa91f1dc951329a6d276731308d657eec644c8 | |
parent | Bugfixes for multi-vector token handling (diff) | |
download | textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.tar.gz textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.tar.bz2 textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.zip |
Update
-rw-r--r-- | common.py | 24 | ||||
-rw-r--r-- | models/clip/embeddings.py | 30 | ||||
-rw-r--r-- | models/clip/tokenizer.py | 6 | ||||
-rw-r--r-- | train_ti.py | 1 |
4 files changed, 35 insertions, 26 deletions
@@ -24,13 +24,21 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC | |||
24 | return [] | 24 | return [] |
25 | 25 | ||
26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] | 26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] |
27 | tokens = [filename.stem for filename in filenames] | ||
28 | 27 | ||
29 | for filename in embeddings_dir.iterdir(): | 28 | new_tokens = [] |
30 | if filename.is_file(): | 29 | new_embeds = [] |
31 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
32 | embed = file.get_tensor("embed") | ||
33 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | ||
34 | embeddings.add_embed(added.ids, embed) | ||
35 | 30 | ||
36 | return tokens | 31 | for filename in filenames: |
32 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
33 | embed = file.get_tensor("embed") | ||
34 | |||
35 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | ||
36 | new_tokens.append(added) | ||
37 | new_embeds.append(embed) | ||
38 | |||
39 | embeddings.resize(len(tokenizer)) | ||
40 | |||
41 | for (new_token, embeds) in zip(new_tokens, new_embeds): | ||
42 | embeddings.add_embed(new_token.ids, embeds) | ||
43 | |||
44 | return new_tokens | ||
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 91a575d..cab1515 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -12,18 +12,22 @@ from transformers.models.clip import CLIPTextConfig | |||
12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings |
13 | 13 | ||
14 | 14 | ||
15 | def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding: | 15 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding: |
16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.size() | 16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.size() |
17 | 17 | ||
18 | if old_num_embeddings == new_num_embeddings: | ||
19 | return old_embedding | ||
20 | |||
21 | n = min(old_num_embeddings, new_num_embeddings) | ||
22 | |||
18 | new_embedding = nn.Embedding( | 23 | new_embedding = nn.Embedding( |
19 | old_num_embeddings + n, | 24 | new_num_embeddings, |
20 | old_embedding_dim, | 25 | old_embedding_dim, |
21 | device=old_embedding.weight.device, | 26 | device=old_embedding.weight.device, |
22 | dtype=old_embedding.weight.dtype | 27 | dtype=old_embedding.weight.dtype |
23 | ) | 28 | ) |
24 | new_embedding.weight.data.zero_() | 29 | new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) |
25 | new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data | 30 | new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] |
26 | |||
27 | return new_embedding | 31 | return new_embedding |
28 | 32 | ||
29 | 33 | ||
@@ -40,9 +44,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
40 | device=self.token_embedding.weight.device, | 44 | device=self.token_embedding.weight.device, |
41 | dtype=self.token_embedding.weight.dtype | 45 | dtype=self.token_embedding.weight.dtype |
42 | ) | 46 | ) |
43 | self.temp_token_embedding.weight.data.zero_() | 47 | self.temp_token_embedding.weight.data.normal_(mean=0.0, std=config.initializer_factor * 0.02) |
44 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 48 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
45 | 49 | ||
50 | def resize(self, size: int): | ||
51 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.config.initializer_factor) | ||
52 | self.token_embedding = resize_embedding(self.token_embedding, size, self.config.initializer_factor) | ||
53 | |||
46 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 54 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
47 | if isinstance(token_ids, int): | 55 | if isinstance(token_ids, int): |
48 | token_ids = [token_ids] | 56 | token_ids = [token_ids] |
@@ -55,20 +63,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
55 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 63 | initializer = (initializer * len(token_ids))[:len(token_ids)] |
56 | 64 | ||
57 | with torch.no_grad(): | 65 | with torch.no_grad(): |
58 | initializer = self.get_embed(initializer) | 66 | initializer = self.get_embed(initializer).to(dtype=self.temp_token_embedding.weight.dtype) |
59 | |||
60 | self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids)) | ||
61 | self.token_embedding = expand_embedding(self.token_embedding, len(token_ids)) | ||
62 | 67 | ||
63 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 68 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
64 | 69 | ||
65 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 70 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
66 | 71 | ||
67 | if initializer is not None: | 72 | if initializer is not None: |
68 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( | 73 | self.temp_token_embedding.weight.data[token_ids] = initializer |
69 | dtype=self.temp_token_embedding.weight.dtype) | ||
70 | else: | ||
71 | self.temp_token_embedding.weight.data[token_ids].zero_() | ||
72 | 74 | ||
73 | def load_embed(self, input_ids: list[int], filename: Path): | 75 | def load_embed(self, input_ids: list[int], filename: Path): |
74 | with safe_open(filename, framework="pt", device="cpu") as file: | 76 | with safe_open(filename, framework="pt", device="cpu") as file: |
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 63566e0..fbfe790 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -8,7 +8,6 @@ from transformers import CLIPTokenizer | |||
8 | 8 | ||
9 | class MultiCLIPTokenizerItem(NamedTuple): | 9 | class MultiCLIPTokenizerItem(NamedTuple): |
10 | token: str | 10 | token: str |
11 | meta_id: int | ||
12 | ids: list[int] | 11 | ids: list[int] |
13 | 12 | ||
14 | 13 | ||
@@ -38,11 +37,10 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
38 | super().add_tokens(multi_token) | 37 | super().add_tokens(multi_token) |
39 | 38 | ||
40 | ids = super().convert_tokens_to_ids(multi_token) | 39 | ids = super().convert_tokens_to_ids(multi_token) |
41 | meta_id = ids[0] | ||
42 | 40 | ||
43 | self.token_map[meta_id] = ids | 41 | self.token_map[ids[0]] = ids |
44 | 42 | ||
45 | return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) | 43 | return MultiCLIPTokenizerItem(new_tokens, ids) |
46 | 44 | ||
47 | def expand_id(self, id: int, vector_shuffle=True): | 45 | def expand_id(self, id: int, vector_shuffle=True): |
48 | if id in self.token_map: | 46 | if id in self.token_map: |
diff --git a/train_ti.py b/train_ti.py index 3776eb2..19348e5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -535,6 +535,7 @@ def main(): | |||
535 | ] | 535 | ] |
536 | 536 | ||
537 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) | 537 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) |
538 | embeddings.resize(len(tokenizer)) | ||
538 | 539 | ||
539 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): | 540 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): |
540 | embeddings.add_embed(new_token.ids, init_ids) | 541 | embeddings.add_embed(new_token.ids, init_ids) |