summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-17 12:31:23 +0100
committerVolpeon <git@volpeon.ink>2023-01-17 12:31:23 +0100
commit94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4 (patch)
tree38fb1e2ab0d7c772f4e33712a2f774188d797319 /models
parentSmaller emb decay (diff)
downloadtextual-inversion-diff-94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4.tar.gz
textual-inversion-diff-94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4.tar.bz2
textual-inversion-diff-94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4.zip
Optimized embedding normalization
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py7
1 file changed, 2 insertions, 5 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1cc59d9..6c41c33 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -3,7 +3,6 @@ from pathlib import Path
3 3
4import torch 4import torch
5import torch.nn as nn 5import torch.nn as nn
6import torch.nn.functional as F
7 6
8from safetensors import safe_open 7from safetensors import safe_open
9from safetensors.torch import save_file 8from safetensors.torch import save_file
@@ -104,10 +103,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
104 return 103 return
105 104
106 w = self.temp_token_embedding.weight 105 w = self.temp_token_embedding.weight
107 pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) 106 norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
108 w[self.temp_token_ids] = F.normalize( 107 w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm))
109 w[self.temp_token_ids, :], dim=-1
110 ) * (pre_norm + lambda_ * (target - pre_norm))
111 108
112 def forward( 109 def forward(
113 self, 110 self,