diff options
| -rw-r--r-- | models/clip/embeddings.py | 10 | ||||
| -rw-r--r-- | train_ti.py | 11 |
2 files changed, 12 insertions, 9 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9d8f770..46b414b 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py
| @@ -3,6 +3,7 @@ from pathlib import Path | |||
| 3 | 3 | ||
| 4 | import torch | 4 | import torch |
| 5 | import torch.nn as nn | 5 | import torch.nn as nn |
| 6 | import torch.nn.functional as F | ||
| 6 | 7 | ||
| 7 | from safetensors import safe_open | 8 | from safetensors import safe_open |
| 8 | from safetensors.torch import save_file | 9 | from safetensors.torch import save_file |
| @@ -45,7 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 45 | device=self.token_embedding.weight.device, | 46 | device=self.token_embedding.weight.device, |
| 46 | dtype=self.token_embedding.weight.dtype | 47 | dtype=self.token_embedding.weight.dtype |
| 47 | ) | 48 | ) |
| 48 | self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02) | 49 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() |
| 49 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 50 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 50 | 51 | ||
| 51 | def resize(self, size: int): | 52 | def resize(self, size: int): |
| @@ -98,6 +99,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 98 | 99 | ||
| 99 | return embeds | 100 | return embeds |
| 100 | 101 | ||
| 102 | def normalize(self, lambda_: float = 1.0): | ||
| 103 | w = self.temp_token_embedding.weight | ||
| 104 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | ||
| 105 | w[self.temp_token_ids] = F.normalize( | ||
| 106 | w[self.temp_token_ids, :], dim=-1 | ||
| 107 | ) * (pre_norm + lambda_ * (0.4 - pre_norm)) | ||
| 108 | |||
| 101 | def forward( | 109 | def forward( |
| 102 | self, | 110 | self, |
| 103 | input_ids: Optional[torch.LongTensor] = None, | 111 | input_ids: Optional[torch.LongTensor] = None, |
diff --git a/train_ti.py b/train_ti.py index 2c5037f..890c465 100644 --- a/train_ti.py +++ b/train_ti.py
| @@ -7,7 +7,6 @@ from pathlib import Path | |||
| 7 | from contextlib import contextmanager, nullcontext | 7 | from contextlib import contextmanager, nullcontext |
| 8 | 8 | ||
| 9 | import torch | 9 | import torch |
| 10 | import torch.nn.functional as F | ||
| 11 | import torch.utils.checkpoint | 10 | import torch.utils.checkpoint |
| 12 | 11 | ||
| 13 | from accelerate import Accelerator | 12 | from accelerate import Accelerator |
| @@ -166,7 +165,7 @@ def parse_args(): | |||
| 166 | parser.add_argument( | 165 | parser.add_argument( |
| 167 | "--tag_dropout", | 166 | "--tag_dropout", |
| 168 | type=float, | 167 | type=float, |
| 169 | default=0.1, | 168 | default=0, |
| 170 | help="Tag dropout probability.", | 169 | help="Tag dropout probability.", |
| 171 | ) | 170 | ) |
| 172 | parser.add_argument( | 171 | parser.add_argument( |
| @@ -177,7 +176,7 @@ def parse_args(): | |||
| 177 | parser.add_argument( | 176 | parser.add_argument( |
| 178 | "--vector_dropout", | 177 | "--vector_dropout", |
| 179 | type=int, | 178 | type=int, |
| 180 | default=0.1, | 179 | default=0, |
| 181 | help="Vector dropout probability.", | 180 | help="Vector dropout probability.", |
| 182 | ) | 181 | ) |
| 183 | parser.add_argument( | 182 | parser.add_argument( |
| @@ -869,11 +868,7 @@ def main(): | |||
| 869 | 868 | ||
| 870 | @torch.no_grad() | 869 | @torch.no_grad() |
| 871 | def on_clip(lr): | 870 | def on_clip(lr): |
| 872 | embeddings = text_encoder.text_model.embeddings.temp_token_embedding | 871 | text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) |
| 873 | |||
| 874 | pre_norm = embeddings.weight.norm(dim=-1, keepdim=True) | ||
| 875 | lambda_ = min(1.0, 100 * lr) | ||
| 876 | embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)) | ||
| 877 | 872 | ||
| 878 | loop = partial( | 873 | loop = partial( |
| 879 | loss_step, | 874 | loss_step, |
