author    Volpeon <git@volpeon.ink>  2023-01-13 18:59:26 +0100
committer Volpeon <git@volpeon.ink>  2023-01-13 18:59:26 +0100
commit    127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch)
tree      61cb98adbf33ed08506601f8b70f1b62bc42c4ee /models/clip/embeddings.py
parent    Simplified step calculations (diff)
More modularization
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r--  models/clip/embeddings.py  6
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 761efbc..9a23a2a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -40,8 +40,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
 
-        self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item()
-
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -101,9 +99,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, target: Optional[float] = None, lambda_: float = 1.0):
-        if target is None:
-            target = self.decay_target
+    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
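
The hunk cuts off inside the F.normalize(...) call, so for context, here is a minimal, self-contained sketch of what normalize() does after this commit. The TempEmbeddings wrapper and the final rescaling expression (blending each embedding's current norm toward the fixed target) are illustrative assumptions, not the repository's exact code; the new default target of 0.4 presumably stands in for the per-checkpoint median norm that decay_target used to compute at init.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TempEmbeddings(nn.Module):
    # Hypothetical stand-in for ManagedCLIPTextEmbeddings; only the pieces
    # that normalize() touches are reproduced here.
    def __init__(self, num_embeddings: int = 16, embedding_dim: int = 8):
        super().__init__()
        self.temp_token_embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.temp_token_ids = torch.tensor([3, 5], dtype=torch.long)

    @torch.no_grad()
    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
        w = self.temp_token_embedding.weight
        pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
        # Assumed completion of the truncated line: rescale each temp
        # embedding so its norm moves lambda_ of the way from its current
        # value toward the fixed target.
        w[self.temp_token_ids] = F.normalize(w[self.temp_token_ids, :], dim=-1) * (
            pre_norm + lambda_ * (target - pre_norm)
        )


m = TempEmbeddings()
m.normalize()
# Each temp embedding now has norm ~0.4 (exactly, since lambda_ defaults to 1.0).
print(m.temp_token_embedding.weight[m.temp_token_ids].norm(dim=-1))

One consequence of the change: callers that relied on the old None default (which fell back to the median-norm decay_target) must now compute and pass a target themselves if they want data-dependent behaviour.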