From 89d471652644f449966a0cd944041c98dab7f66c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 07:25:24 +0100 Subject: Code deduplication --- models/clip/embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'models/clip') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 46b414b..9a23a2a 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -99,12 +99,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return embeds - def normalize(self, lambda_: float = 1.0): + def normalize(self, target: float = 0.4, lambda_: float = 1.0): w = self.temp_token_embedding.weight pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) w[self.temp_token_ids] = F.normalize( w[self.temp_token_ids, :], dim=-1 - ) * (pre_norm + lambda_ * (0.4 - pre_norm)) + ) * (pre_norm + lambda_ * (target - pre_norm)) def forward( self, -- cgit v1.2.3-54-g00ecf