diff options
author | Volpeon <git@volpeon.ink> | 2023-01-13 18:59:26 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-13 18:59:26 +0100 |
commit | 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch) | |
tree | 61cb98adbf33ed08506601f8b70f1b62bc42c4ee /models/clip | |
parent | Simplified step calculations (diff) | |
download | textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.gz textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.bz2 textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.zip |
More modularization
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 761efbc..9a23a2a 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -40,8 +40,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
40 | self.position_embedding = embeddings.position_embedding | 40 | self.position_embedding = embeddings.position_embedding |
41 | self.initializer_factor = config.initializer_factor | 41 | self.initializer_factor = config.initializer_factor |
42 | 42 | ||
43 | self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item() | ||
44 | |||
45 | self.temp_token_embedding = nn.Embedding( | 43 | self.temp_token_embedding = nn.Embedding( |
46 | self.token_embedding.num_embeddings, | 44 | self.token_embedding.num_embeddings, |
47 | self.token_embedding.embedding_dim, | 45 | self.token_embedding.embedding_dim, |
@@ -101,9 +99,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
101 | 99 | ||
102 | return embeds | 100 | return embeds |
103 | 101 | ||
104 | def normalize(self, target: Optional[float] = None, lambda_: float = 1.0): | 102 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): |
105 | if target is None: | ||
106 | target = self.decay_target | ||
107 | w = self.temp_token_embedding.weight | 103 | w = self.temp_token_embedding.weight |
108 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | 104 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) |
109 | w[self.temp_token_ids] = F.normalize( | 105 | w[self.temp_token_ids] = F.normalize( |