summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py10
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9d8f770..46b414b 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -3,6 +3,7 @@ from pathlib import Path
3 3
4import torch 4import torch
5import torch.nn as nn 5import torch.nn as nn
6import torch.nn.functional as F
6 7
7from safetensors import safe_open 8from safetensors import safe_open
8from safetensors.torch import save_file 9from safetensors.torch import save_file
@@ -45,7 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
45 device=self.token_embedding.weight.device, 46 device=self.token_embedding.weight.device,
46 dtype=self.token_embedding.weight.dtype 47 dtype=self.token_embedding.weight.dtype
47 ) 48 )
48 self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02) 49 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
49 self.temp_token_ids = torch.tensor([], dtype=torch.long) 50 self.temp_token_ids = torch.tensor([], dtype=torch.long)
50 51
51 def resize(self, size: int): 52 def resize(self, size: int):
@@ -98,6 +99,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
98 99
99 return embeds 100 return embeds
100 101
102 def normalize(self, lambda_: float = 1.0):
103 w = self.temp_token_embedding.weight
104 pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
105 w[self.temp_token_ids] = F.normalize(
106 w[self.temp_token_ids, :], dim=-1
107 ) * (pre_norm + lambda_ * (0.4 - pre_norm))
108
101 def forward( 109 def forward(
102 self, 110 self,
103 input_ids: Optional[torch.LongTensor] = None, 111 input_ids: Optional[torch.LongTensor] = None,