diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-27 10:16:10 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-27 10:16:10 +0200 |
| commit | a412196d1a3b616655de52fb12e0d8528e1f1af0 (patch) | |
| tree | a97dec23f16231b5ca8c1092a009c983b6971880 | |
| parent | Sparse TI embeddings without sparse tensors (diff) | |
| download | textual-inversion-diff-a412196d1a3b616655de52fb12e0d8528e1f1af0.tar.gz textual-inversion-diff-a412196d1a3b616655de52fb12e0d8528e1f1af0.tar.bz2 textual-inversion-diff-a412196d1a3b616655de52fb12e0d8528e1f1af0.zip | |
Revert to regular embeddings
| -rw-r--r-- | models/clip/embeddings.py | 34 |
1 file changed, 15 insertions, 19 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 95904cf..2b315c4 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -42,16 +42,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 42 | self.init_temp_embeddings() | 42 | self.init_temp_embeddings() |
| 43 | 43 | ||
| 44 | def init_temp_embeddings(self): | 44 | def init_temp_embeddings(self): |
| 45 | self.temp_token_embedding = nn.ParameterList() | 45 | self.temp_token_embedding = nn.Embedding( |
| 46 | 0, | ||
| 47 | self.token_embedding.embedding_dim, | ||
| 48 | device=self.token_embedding.weight.device, | ||
| 49 | dtype=self.token_embedding.weight.dtype | ||
| 50 | ) | ||
| 46 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 51 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 47 | 52 | ||
| 48 | def resize(self, size: int): | 53 | def resize(self, size: int): |
| 49 | for _ in range(len(self.temp_token_embedding), size): | 54 | self.temp_token_embedding = resize_embedding( |
| 50 | self.temp_token_embedding.append(torch.zeros( | 55 | self.temp_token_embedding, |
| 51 | self.token_embedding.embedding_dim, | 56 | size - self.num_permanent_embeddings, |
| 52 | device=self.token_embedding.weight.device, | 57 | self.initializer_factor |
| 53 | dtype=self.token_embedding.weight.dtype, | 58 | ) |
| 54 | )) | ||
| 55 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 59 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
| 56 | 60 | ||
| 57 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 61 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
| @@ -78,10 +82,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 78 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 82 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 79 | 83 | ||
| 80 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 84 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 81 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) | 85 | mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) |
| 82 | 86 | self.temp_token_embedding.weight.data[mask] = initializer | |
| 83 | for i, id in enumerate(mask): | ||
| 84 | self.temp_token_embedding[id] = initializer[i] | ||
| 85 | 87 | ||
| 86 | def load_embed(self, input_ids: list[int], filename: Path): | 88 | def load_embed(self, input_ids: list[int], filename: Path): |
| 87 | with safe_open(filename, framework="pt", device="cpu") as file: | 89 | with safe_open(filename, framework="pt", device="cpu") as file: |
| @@ -91,8 +93,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 91 | save_file({"embed": self.get_embed(input_ids)}, filename) | 93 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 92 | 94 | ||
| 93 | def persist(self): | 95 | def persist(self): |
| 94 | for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): | 96 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] |
| 95 | self.token_embedding.weight.data[id] = emb | ||
| 96 | self.num_permanent_embeddings = self.token_embedding.num_embeddings | 97 | self.num_permanent_embeddings = self.token_embedding.num_embeddings |
| 97 | self.init_temp_embeddings() | 98 | self.init_temp_embeddings() |
| 98 | 99 | ||
| @@ -111,12 +112,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 111 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | 112 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) |
| 112 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | 113 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() |
| 113 | 114 | ||
| 114 | if len(temp_token_ids): | 115 | embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) |
| 115 | embeds_override = torch.stack([ | ||
| 116 | self.temp_token_embedding[id] | ||
| 117 | for id in temp_token_ids | ||
| 118 | ]) | ||
| 119 | embeds[embeds_mask] = embeds_override | ||
| 120 | 116 | ||
| 121 | return embeds | 117 | return embeds |
| 122 | 118 | ||
