summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r--  models/clip/embeddings.py | 10
-rw-r--r--  train_ti.py               | 11
2 files changed, 12 insertions(+), 9 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9d8f770..46b414b 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -3,6 +3,7 @@ from pathlib import Path
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -45,7 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
-        self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02)
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
@@ -98,6 +99,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
+    def normalize(self, lambda_: float = 1.0):
+        w = self.temp_token_embedding.weight
+        pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
+        w[self.temp_token_ids] = F.normalize(
+            w[self.temp_token_ids, :], dim=-1
+        ) * (pre_norm + lambda_ * (0.4 - pre_norm))
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
diff --git a/train_ti.py b/train_ti.py
index 2c5037f..890c465 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -7,7 +7,6 @@ from pathlib import Path
 from contextlib import contextmanager, nullcontext
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from accelerate import Accelerator
@@ -166,7 +165,7 @@ def parse_args():
166 parser.add_argument( 165 parser.add_argument(
167 "--tag_dropout", 166 "--tag_dropout",
168 type=float, 167 type=float,
169 default=0.1, 168 default=0,
170 help="Tag dropout probability.", 169 help="Tag dropout probability.",
171 ) 170 )
172 parser.add_argument( 171 parser.add_argument(
@@ -177,7 +176,7 @@ def parse_args():
     parser.add_argument(
         "--vector_dropout",
         type=int,
-        default=0.1,
+        default=0,
         help="Vector dropout probability.",
     )
     parser.add_argument(
@@ -869,11 +868,7 @@ def main():
 
     @torch.no_grad()
     def on_clip(lr):
-        embeddings = text_encoder.text_model.embeddings.temp_token_embedding
-
-        pre_norm = embeddings.weight.norm(dim=-1, keepdim=True)
-        lambda_ = min(1.0, 100 * lr)
-        embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm))
+        text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
 
     loop = partial(
         loss_step,