diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9d8f770..46b414b 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -3,6 +3,7 @@ from pathlib import Path | |||
3 | 3 | ||
4 | import torch | 4 | import torch |
5 | import torch.nn as nn | 5 | import torch.nn as nn |
6 | import torch.nn.functional as F | ||
6 | 7 | ||
7 | from safetensors import safe_open | 8 | from safetensors import safe_open |
8 | from safetensors.torch import save_file | 9 | from safetensors.torch import save_file |
@@ -45,7 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
45 | device=self.token_embedding.weight.device, | 46 | device=self.token_embedding.weight.device, |
46 | dtype=self.token_embedding.weight.dtype | 47 | dtype=self.token_embedding.weight.dtype |
47 | ) | 48 | ) |
48 | self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02) | 49 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() |
49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 50 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
50 | 51 | ||
51 | def resize(self, size: int): | 52 | def resize(self, size: int): |
@@ -98,6 +99,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
98 | 99 | ||
99 | return embeds | 100 | return embeds |
100 | 101 | ||
def normalize(self, lambda_: float = 1.0, target: float = 0.4):
    """Pull the L2 norms of the temporary token embeddings toward ``target``.

    Only the rows of ``self.temp_token_embedding.weight`` selected by
    ``self.temp_token_ids`` are rewritten. Each selected row keeps its
    direction; its norm moves from the current value toward ``target`` by
    the fraction ``lambda_``:

        new_norm = pre_norm + lambda_ * (target - pre_norm)

    so ``lambda_=1`` sets every selected norm to ``target`` and
    ``lambda_=0`` leaves the norms as they are (up to rounding).

    Args:
        lambda_: Interpolation factor between the current norm and
            ``target``.
        target: Norm to move toward. Defaults to 0.4, the value that was
            previously hard-coded.
    """
    weight = self.temp_token_embedding.weight
    token_ids = self.temp_token_ids

    # The weight is a leaf parameter. Writing into it in place while
    # autograd is recording raises a RuntimeError, and this rescaling is
    # not meant to be part of the graph anyway, so do it under no_grad.
    with torch.no_grad():
        rows = weight[token_ids, :]
        pre_norm = rows.norm(dim=-1, keepdim=True)
        new_norm = pre_norm + lambda_ * (target - pre_norm)
        weight[token_ids] = F.normalize(rows, dim=-1) * new_norm
108 | |||
101 | def forward( | 109 | def forward( |
102 | self, | 110 | self, |
103 | input_ids: Optional[torch.LongTensor] = None, | 111 | input_ids: Optional[torch.LongTensor] = None, |