From a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 4 Jan 2023 22:06:05 +0100 Subject: Update --- models/clip/embeddings.py | 4 +++- models/clip/tokenizer.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'models/clip') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9c3a56b..1280ebd 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) self.temp_token_embedding.weight.data[token_ids] = initializer.to( - dtype=self.temp_token_embedding.weight.dtype) + device=self.temp_token_embedding.weight.device, + dtype=self.temp_token_embedding.weight.dtype, + ) def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 11a3df0..4e97ab5 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -48,7 +48,7 @@ def shuffle_none(tokens: list[int]): def shuffle_auto(tokens: list[int]): - if len(tokens) >= 4: + if len(tokens) >= 5: return shuffle_between(tokens) if len(tokens) >= 3: return shuffle_trailing(tokens) -- cgit v1.2.3-54-g00ecf