diff options
author | Volpeon <git@volpeon.ink> | 2023-04-04 07:30:43 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-04 07:30:43 +0200 |
commit | 30b557c8e1f03b4748ac3efca599ff51d66561cb (patch) | |
tree | 59aaacde83a7a44dc267c64455f6dc2cfb90c01f /models/clip | |
parent | Improved sparse embeddings (diff) | |
download | textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.tar.gz textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.tar.bz2 textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.zip |
TI: Bring back old embedding decay
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 15 |
1 file changed, 7 insertions(+), 8 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index a356434..63a141f 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
37 | 37 | ||
38 | 38 | ||
39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0): | 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): |
41 | super().__init__(config) | 41 | super().__init__(config) |
42 | 42 | ||
43 | self.token_embedding = embeddings.token_embedding | 43 | self.token_embedding = embeddings.token_embedding |
@@ -49,7 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
49 | device=self.token_embedding.weight.device, | 49 | device=self.token_embedding.weight.device, |
50 | dtype=self.token_embedding.weight.dtype, | 50 | dtype=self.token_embedding.weight.dtype, |
51 | ) | 51 | ) |
52 | self.alpha = alpha | ||
53 | 52 | ||
54 | def resize(self, size: int): | 53 | def resize(self, size: int): |
55 | self.token_override_embedding.resize(size) | 54 | self.token_override_embedding.resize(size) |
@@ -87,7 +86,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
87 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 86 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
88 | 87 | ||
89 | self.token_embedding.weight.data[token_ids] = initializer | 88 | self.token_embedding.weight.data[token_ids] = initializer |
90 | self.token_override_embedding.set(token_ids) | 89 | self.token_override_embedding.set(token_ids, initializer) |
91 | 90 | ||
92 | def load_embed(self, input_ids: list[int], filename: Path): | 91 | def load_embed(self, input_ids: list[int], filename: Path): |
93 | with safe_open(filename, framework="pt", device="cpu") as file: | 92 | with safe_open(filename, framework="pt", device="cpu") as file: |
@@ -101,8 +100,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
101 | embs, mask = self.token_override_embedding(input_ids) | 100 | embs, mask = self.token_override_embedding(input_ids) |
102 | if embs is not None: | 101 | if embs is not None: |
103 | input_ids = input_ids[mask] | 102 | input_ids = input_ids[mask] |
104 | self.token_embedding.weight.data[input_ids] += self.alpha * embs | 103 | self.token_embedding.weight.data[input_ids] = embs |
105 | self.token_override_embedding.unset(input_ids) | 104 | self.token_override_embedding.unset(input_ids) |
106 | 105 | ||
107 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 106 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
108 | if isinstance(input_ids, list): | 107 | if isinstance(input_ids, list): |
@@ -111,7 +110,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
111 | embs = self.token_embedding(input_ids) | 110 | embs = self.token_embedding(input_ids) |
112 | embs_override, mask = self.token_override_embedding(input_ids) | 111 | embs_override, mask = self.token_override_embedding(input_ids) |
113 | if embs_override is not None: | 112 | if embs_override is not None: |
114 | embs[mask] += self.alpha * embs_override | 113 | embs[mask] = embs_override |
115 | 114 | ||
116 | return embs | 115 | return embs |
117 | 116 | ||
@@ -135,7 +134,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
135 | return embeddings | 134 | return embeddings |
136 | 135 | ||
137 | 136 | ||
138 | def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings: | 137 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: |
139 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha) | 138 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) |
140 | text_encoder.text_model.embeddings = text_embeddings | 139 | text_encoder.text_model.embeddings = text_embeddings |
141 | return text_embeddings | 140 | return text_embeddings |