From 1c63552a20f34bccd461ac0dfa46405f853cbc7c Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 27 Mar 2023 11:58:47 +0200
Subject: Fix TI

Keep temp_token_embedding the same size as token_embedding and initialize
it as a copy of the permanent weights, so temporary token ids index both
tables directly. This drops the num_permanent_embeddings bookkeeping and
the index remapping previously needed in add_embed and get_embed.

---
 models/clip/embeddings.py | 34 +++++++++-------------------------
 1 file changed, 9 insertions(+), 25 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2b315c4..2d60c28 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -38,24 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.num_permanent_embeddings = self.token_embedding.num_embeddings
-        self.init_temp_embeddings()
 
-    def init_temp_embeddings(self):
         self.temp_token_embedding = nn.Embedding(
-            0,
+            self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype
         )
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(
-            self.temp_token_embedding,
-            size - self.num_permanent_embeddings,
-            self.initializer_factor
-        )
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
@@ -75,15 +69,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
                 initializer = self.get_embed(initializer)
 
         initializer = initializer.to(
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype,
+            device=self.temp_token_embedding.weight.device,
+            dtype=self.temp_token_embedding.weight.dtype,
         )
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-        mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1)
-        self.temp_token_embedding.weight.data[mask] = initializer
+        self.temp_token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -94,25 +87,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def persist(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.num_permanent_embeddings = self.token_embedding.num_embeddings
-        self.init_temp_embeddings()
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
-
         embeds = self.token_embedding(input_ids)
 
-        embeds_mask = torch.isin(input_ids, all_temp_token_ids)
-        temp_token_ids = input_ids[embeds_mask]
-
-        temp_token_ids = temp_token_ids.unsqueeze(1)
-        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
-        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
-
-        embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids)
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
+        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
 
         return embeds
 
-- 
cgit v1.2.3-70-g09d2
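
For reference, the lookup scheme the patched code settles on can be
reproduced standalone. A minimal sketch, not part of the patch; the table
sizes, the ids 7 and 9, and the random initializer are illustrative only:

import torch
import torch.nn as nn

# Two same-sized tables: the permanent embeddings and a trainable copy that
# holds the temporary (textual inversion) vectors at the same indices.
token_embedding = nn.Embedding(10, 4)
temp_token_embedding = nn.Embedding(10, 4)
temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()
temp_token_ids = torch.tensor([], dtype=torch.long)

# add_embed: new ids index the temp table directly, no offset remapping.
new_ids = torch.tensor([7, 9], dtype=torch.long)
temp_token_ids = torch.cat([temp_token_ids, new_ids])
temp_token_embedding.weight.data[new_ids] = torch.randn(2, 4)

# get_embed: look everything up in the permanent table, then overwrite the
# positions that hold temporary tokens from the temp table.
input_ids = torch.tensor([[1, 7, 3, 9]])
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids.to(input_ids.device))
embeds[mask] = temp_token_embedding(input_ids)[mask]

# persist: copy the trained rows back and forget which ids were temporary.
token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]
temp_token_ids = torch.tensor([], dtype=torch.long)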