diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 761efbc..9a23a2a 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -40,8 +40,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
40 | self.position_embedding = embeddings.position_embedding | 40 | self.position_embedding = embeddings.position_embedding |
41 | self.initializer_factor = config.initializer_factor | 41 | self.initializer_factor = config.initializer_factor |
42 | 42 | ||
43 | self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item() | ||
44 | |||
45 | self.temp_token_embedding = nn.Embedding( | 43 | self.temp_token_embedding = nn.Embedding( |
46 | self.token_embedding.num_embeddings, | 44 | self.token_embedding.num_embeddings, |
47 | self.token_embedding.embedding_dim, | 45 | self.token_embedding.embedding_dim, |
@@ -101,9 +99,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
101 | 99 | ||
102 | return embeds | 100 | return embeds |
103 | 101 | ||
104 | def normalize(self, target: Optional[float] = None, lambda_: float = 1.0): | 102 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): |
105 | if target is None: | ||
106 | target = self.decay_target | ||
107 | w = self.temp_token_embedding.weight | 103 | w = self.temp_token_embedding.weight |
108 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | 104 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) |
109 | w[self.temp_token_ids] = F.normalize( | 105 | w[self.temp_token_ids] = F.normalize( |