From b2db2b6a7c147cdc2901ece92f3918e5b3c47114 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sat, 31 Dec 2022 23:35:11 +0100
Subject: Fix: cache initializer_factor on the instance and cast initializer at assignment

---
 models/clip/embeddings.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index cab1515..f90e7c2 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,6 +37,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
+        self.initializer_factor = config.initializer_factor
 
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
@@ -44,12 +45,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
-        self.temp_token_embedding.weight.data.normal_(mean=0.0, std=config.initializer_factor * 0.02)
+        self.temp_token_embedding.weight.data.normal_(mean=0.0, std=self.initializer_factor * 0.02)
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.config.initializer_factor)
-        self.token_embedding = resize_embedding(self.token_embedding, size, self.config.initializer_factor)
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
+        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
         if isinstance(token_ids, int):
@@ -63,14 +64,15 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
                 initializer = (initializer * len(token_ids))[:len(token_ids)]
 
                 with torch.no_grad():
-                    initializer = self.get_embed(initializer).to(dtype=self.temp_token_embedding.weight.dtype)
+                    initializer = self.get_embed(initializer)
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
 
         if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer
+            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+                dtype=self.temp_token_embedding.weight.dtype)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
-- 
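
For context (not part of the commit), below is a minimal usage sketch of the patched code path. It assumes the ManagedCLIPTextEmbeddings(config, embeddings) constructor implied by the hunks above and a stock Hugging Face CLIP text encoder; the model id and token strings are illustrative only.

    # Usage sketch; everything not shown in the diff above is an assumption.
    from transformers import CLIPTextModel, CLIPTokenizer

    from models.clip.embeddings import ManagedCLIPTextEmbeddings

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Wrap the stock embeddings; after this patch, initializer_factor is cached
    # on the instance instead of being read from self.config later.
    managed = ManagedCLIPTextEmbeddings(
        text_encoder.config, text_encoder.text_model.embeddings
    )
    text_encoder.text_model.embeddings = managed

    # Register a placeholder token and grow both embedding tables via resize(),
    # which now uses self.initializer_factor.
    tokenizer.add_tokens("<my-token>")
    managed.resize(len(tokenizer))

    # Seed the new embedding from an existing token; the dtype cast now happens
    # at assignment time rather than on the result of get_embed().
    new_id = tokenizer.convert_tokens_to_ids("<my-token>")
    src_id = tokenizer.encode("cat", add_special_tokens=False)[0]
    managed.add_embed(new_id, initializer=src_id)

The commit itself only changes where initializer_factor is read from and when the initializer tensor is cast to the embedding dtype; the surrounding setup in the sketch is assumed context.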