author    Volpeon <git@volpeon.ink>  2023-04-04 07:30:43 +0200
committer Volpeon <git@volpeon.ink>  2023-04-04 07:30:43 +0200
commit    30b557c8e1f03b4748ac3efca599ff51d66561cb (patch)
tree      59aaacde83a7a44dc267c64455f6dc2cfb90c01f /models/clip
parent    Improved sparse embeddings (diff)
TI: Bring back old embedding decay
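
In effect, a trained token override no longer gets blended into the base
weight as base + alpha * override; it replaces the base embedding outright,
restoring the behavior from before the alpha blend was introduced. A minimal
sketch of the two semantics with bare tensors (a standalone illustration,
not the repo's token_override_embedding module):

    import torch

    embs = torch.zeros(4, 8)       # base token embeddings: 4 tokens, dim 8
    override = torch.ones(2, 8)    # trained override vectors for two tokens
    mask = torch.tensor([1, 3])    # positions that carry overrides

    # Before this commit: additive blend, scaled by alpha
    alpha = 1.0
    blended = embs.clone()
    blended[mask] += alpha * override

    # After this commit: the override is used verbatim
    replaced = embs.clone()
    replaced[mask] = override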
Diffstat (limited to 'models/clip')
-rw-r--r-- models/clip/embeddings.py | 15 +++++++--------
1 file changed, 7 insertions(+), 8 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index a356434..63a141f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -49,7 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype,
         )
-        self.alpha = alpha
 
     def resize(self, size: int):
         self.token_override_embedding.resize(size)
@@ -87,7 +86,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.weight.data[token_ids] = initializer
-        self.token_override_embedding.set(token_ids)
+        self.token_override_embedding.set(token_ids, initializer)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -101,8 +100,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
-            self.token_embedding.weight.data[input_ids] += self.alpha * embs
+            self.token_embedding.weight.data[input_ids] = embs
             self.token_override_embedding.unset(input_ids)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
@@ -111,7 +110,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         embs = self.token_embedding(input_ids)
         embs_override, mask = self.token_override_embedding(input_ids)
         if embs_override is not None:
-            embs[mask] += self.alpha * embs_override
+            embs[mask] = embs_override
 
         return embs
 
@@ -135,7 +134,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha)
+def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
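
With alpha gone from the public surface, patching a text encoder takes only
the encoder itself. A hypothetical usage sketch (the model id is an example,
and the import path assumes the repo root is on sys.path):

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Swap in the managed embeddings; overrides registered later via
    # token_override_embedding.set() are returned verbatim by get_embed(),
    # with no alpha scaling applied.
    managed = patch_managed_embeddings(text_encoder)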