From a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Wed, 4 Jan 2023 22:06:05 +0100
Subject: Update

---
 models/clip/embeddings.py | 4 +++-
 models/clip/tokenizer.py  | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

(limited to 'models/clip')

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9c3a56b..1280ebd 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
         self.temp_token_embedding.weight.data[token_ids] = initializer.to(
-            dtype=self.temp_token_embedding.weight.dtype)
+            device=self.temp_token_embedding.weight.device,
+            dtype=self.temp_token_embedding.weight.dtype,
+        )
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 11a3df0..4e97ab5 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -48,7 +48,7 @@ def shuffle_none(tokens: list[int]):
 
 
 def shuffle_auto(tokens: list[int]):
-    if len(tokens) >= 4:
+    if len(tokens) >= 5:
         return shuffle_between(tokens)
     if len(tokens) >= 3:
         return shuffle_trailing(tokens)
-- 
cgit v1.2.3-70-g09d2