author     Volpeon <git@volpeon.ink>  2022-12-31 17:12:12 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-31 17:12:12 +0100
commit     b42e7fbc29fd8045a2b932eb8ae76587f51f7513 (patch)
tree       85321e605cd8e183a0b9e05efcc4282921e667e0 /models/clip/embeddings.py
parent     Simplified multi-vector embedding code (diff)
Bugfixes for multi-vector token handling
Diffstat (limited to 'models/clip/embeddings.py')

 models/clip/embeddings.py | 27 ++++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)
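The bugs fixed below are dtype mismatches: an id buffer created with torch.tensor([]) defaults to float32, which later breaks integer indexing into the embedding weights; the resized embedding was also moved with a bare .to() call instead of being constructed on the right device/dtype, and the initializer could arrive in a different dtype than the weights. A minimal repro sketch of the indexing failure (illustration only, not part of the commit; the ids and table size are just examples):

    import torch

    # torch.tensor([]) defaults to float32, so the temp id buffer starts as floats
    temp_token_ids = torch.tensor([])
    assert temp_token_ids.dtype == torch.float32

    # concatenating long ids onto it stays float32 (torch.cat promotes
    # the integer operand to the floating dtype)
    temp_token_ids = torch.cat([temp_token_ids, torch.tensor([49408, 49409])])

    # float tensors cannot be used as indices into the weight matrix
    weight = torch.zeros(49410, 768)
    weight[temp_token_ids]  # IndexError: tensors used as indices must be long, byte or bool tensors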
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f82873e..91a575d 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -15,8 +15,12 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
 
-    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
-    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
+    new_embedding = nn.Embedding(
+        old_num_embeddings + n,
+        old_embedding_dim,
+        device=old_embedding.weight.device,
+        dtype=old_embedding.weight.dtype
+    )
     new_embedding.weight.data.zero_()
     new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
 
@@ -31,9 +35,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.position_embedding = embeddings.position_embedding
 
         self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype
+        )
         self.temp_token_embedding.weight.data.zero_()
-        self.temp_token_ids = torch.tensor([])
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
         if isinstance(token_ids, int):
@@ -52,12 +60,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
         self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
 
-        token_ids = torch.tensor(token_ids)
+        token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
 
         if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer
+            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+                dtype=self.temp_token_embedding.weight.dtype)
         else:
             self.temp_token_embedding.weight.data[token_ids].zero_()
 
@@ -70,13 +79,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def make_permanent(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.temp_token_ids = torch.tensor([])
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device)
+            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
 
         embeds = self.token_embedding(input_ids)
         embeds[mask] = self.temp_token_embedding(input_ids)[mask]
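For reference, a quick sanity check of the patched expand_embedding contract (a sketch assuming the function is importable from models/clip/embeddings.py as in this tree; not part of the commit):

    import torch
    import torch.nn as nn

    from models.clip.embeddings import expand_embedding

    old = nn.Embedding(4, 8)
    new = expand_embedding(old, 2)

    # the grown table keeps the source device/dtype, copies the old rows
    # verbatim, and leaves the appended rows zeroed for later initialization
    assert new.num_embeddings == old.num_embeddings + 2
    assert new.weight.dtype == old.weight.dtype
    assert new.weight.device == old.weight.device
    assert torch.equal(new.weight.data[:4], old.weight.data)
    assert new.weight.data[4:].abs().sum().item() == 0.0

With the ids kept as torch.long throughout, make_permanent can copy the trained temp rows back into token_embedding and reset temp_token_ids, after which get_embed falls through to the permanent table for those ids.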