diff options
Diffstat (limited to 'models/clip/embeddings.py')
-rw-r--r-- | models/clip/embeddings.py | 6 |
1 file changed, 5 insertions, 1 deletion
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9a23a2a..761efbc 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -40,6 +40,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
40 | self.position_embedding = embeddings.position_embedding | 40 | self.position_embedding = embeddings.position_embedding |
41 | self.initializer_factor = config.initializer_factor | 41 | self.initializer_factor = config.initializer_factor |
42 | 42 | ||
43 | self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item() | ||
44 | |||
43 | self.temp_token_embedding = nn.Embedding( | 45 | self.temp_token_embedding = nn.Embedding( |
44 | self.token_embedding.num_embeddings, | 46 | self.token_embedding.num_embeddings, |
45 | self.token_embedding.embedding_dim, | 47 | self.token_embedding.embedding_dim, |
@@ -99,7 +101,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
99 | 101 | ||
100 | return embeds | 102 | return embeds |
101 | 103 | ||
102 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): | 104 | def normalize(self, target: Optional[float] = None, lambda_: float = 1.0): |
105 | if target is None: | ||
106 | target = self.decay_target | ||
103 | w = self.temp_token_embedding.weight | 107 | w = self.temp_token_embedding.weight |
104 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | 108 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) |
105 | w[self.temp_token_ids] = F.normalize( | 109 | w[self.temp_token_ids] = F.normalize( |