author    Volpeon <git@volpeon.ink>  2023-02-21 09:09:50 +0100
committer Volpeon <git@volpeon.ink>  2023-02-21 09:09:50 +0100
commit    16b92605a59d59c65789c89b54bb97da51908056
tree      b0cbf8677897c3f44c736b710fd034eb2c5de6a0
parent    Update
Embedding normalization: Ignore tensors with grad = 0
Diffstat (limited to 'models')
 models/clip/embeddings.py | 8 --------
 1 file changed, 0 insertions(+), 8 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6c41c33..734730e 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -98,14 +98,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
-        if lambda_ == 0:
-            return
-
-        w = self.temp_token_embedding.weight
-        norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
-        w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm))
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
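
Note: the commit message promises that normalization now ignores embeddings whose gradient is zero, but this filtered diff only shows the old normalize() being deleted from models/clip/embeddings.py; the replacement presumably lives outside models/, since the diffstat above is limited to that path. As a rough illustration of the stated idea only, here is a minimal PyTorch sketch that keeps the deleted method's target/lambda_ update rule but masks out token rows whose gradient is all zeros (tokens untouched by the current batch). The function name normalize_active and the standalone weight/token_ids parameters are assumptions for illustration, not this repository's actual API.

# Hypothetical sketch; not the commit's real replacement code.
import torch

@torch.no_grad()
def normalize_active(weight: torch.Tensor, token_ids: torch.Tensor,
                     target: float = 0.4, lambda_: float = 1.0) -> None:
    if lambda_ == 0 or weight.grad is None:
        return

    # "Ignore tensors with grad = 0": keep only token rows that received a
    # non-zero gradient in the last backward pass.
    active = weight.grad[token_ids].abs().sum(dim=-1) != 0
    ids = token_ids[active]
    if ids.numel() == 0:
        return

    w = weight[ids]
    norm = w.norm(dim=-1, keepdim=True)
    # Same update rule as the deleted normalize(): nudge each remaining
    # embedding's norm toward `target` by a step scaled by `lambda_`.
    weight[ids] = w + (w / norm.clamp_min(1e-12)) * lambda_ * (target - norm)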