summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py7
1 files changed, 2 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1cc59d9..6c41c33 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -3,7 +3,6 @@ from pathlib import Path
3 3
4import torch 4import torch
5import torch.nn as nn 5import torch.nn as nn
6import torch.nn.functional as F
7 6
8from safetensors import safe_open 7from safetensors import safe_open
9from safetensors.torch import save_file 8from safetensors.torch import save_file
@@ -104,10 +103,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
104 return 103 return
105 104
106 w = self.temp_token_embedding.weight 105 w = self.temp_token_embedding.weight
107 pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) 106 norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
108 w[self.temp_token_ids] = F.normalize( 107 w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm))
109 w[self.temp_token_ids, :], dim=-1
110 ) * (pre_norm + lambda_ * (target - pre_norm))
111 108
112 def forward( 109 def forward(
113 self, 110 self,