summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-17 08:13:39 +0100
committerVolpeon <git@volpeon.ink>2023-01-17 08:13:39 +0100
commit1abbfd5215a99dba9d699e91baec00e6f02a0bd5 (patch)
tree670e846b3c08bd8957955ea56d3a4c4b58a8ad6f /models/clip
parentUpdate (diff)
downloadtextual-inversion-diff-1abbfd5215a99dba9d699e91baec00e6f02a0bd5.tar.gz
textual-inversion-diff-1abbfd5215a99dba9d699e91baec00e6f02a0bd5.tar.bz2
textual-inversion-diff-1abbfd5215a99dba9d699e91baec00e6f02a0bd5.zip
Update
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9a23a2a..1cc59d9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -100,6 +100,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
100 return embeds 100 return embeds
101 101
102 def normalize(self, target: float = 0.4, lambda_: float = 1.0): 102 def normalize(self, target: float = 0.4, lambda_: float = 1.0):
103 if lambda_ == 0:
104 return
105
103 w = self.temp_token_embedding.weight 106 w = self.temp_token_embedding.weight
104 pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) 107 pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
105 w[self.temp_token_ids] = F.normalize( 108 w[self.temp_token_ids] = F.normalize(