From 94c02f8bbc9c57b4def6ab61f5b670bdcd914cf4 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 17 Jan 2023 12:31:23 +0100 Subject: Optimized embedding normalization --- models/clip/embeddings.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'models/clip/embeddings.py') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 1cc59d9..6c41c33 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -3,7 +3,6 @@ from pathlib import Path import torch import torch.nn as nn -import torch.nn.functional as F from safetensors import safe_open from safetensors.torch import save_file @@ -104,10 +103,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return w = self.temp_token_embedding.weight - pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) - w[self.temp_token_ids] = F.normalize( - w[self.temp_token_ids, :], dim=-1 - ) * (pre_norm + lambda_ * (target - pre_norm)) + norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) + w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm)) def forward( self, -- cgit v1.2.3-70-g09d2