diff options
author | Volpeon <git@volpeon.ink> | 2023-01-17 12:31:23 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-17 12:31:23 +0100 |
commit | 94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4 (patch) | |
tree | 38fb1e2ab0d7c772f4e33712a2f774188d797319 /models | |
parent | Smaller emb decay (diff) | |
download | textual-inversion-diff-94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4.tar.gz textual-inversion-diff-94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4.tar.bz2 textual-inversion-diff-94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4.zip |
Optimized embedding normalization
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/embeddings.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 1cc59d9..6c41c33 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -3,7 +3,6 @@ from pathlib import Path | |||
3 | 3 | ||
4 | import torch | 4 | import torch |
5 | import torch.nn as nn | 5 | import torch.nn as nn |
6 | import torch.nn.functional as F | ||
7 | 6 | ||
8 | from safetensors import safe_open | 7 | from safetensors import safe_open |
9 | from safetensors.torch import save_file | 8 | from safetensors.torch import save_file |
@@ -104,10 +103,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
104 | return | 103 | return |
105 | 104 | ||
106 | w = self.temp_token_embedding.weight | 105 | w = self.temp_token_embedding.weight |
107 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | 106 | norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) |
108 | w[self.temp_token_ids] = F.normalize( | 107 | w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm)) |
109 | w[self.temp_token_ids, :], dim=-1 | ||
110 | ) * (pre_norm + lambda_ * (target - pre_norm)) | ||
111 | 108 | ||
112 | def forward( | 109 | def forward( |
113 | self, | 110 | self, |