summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-04 22:06:05 +0100
committerVolpeon <git@volpeon.ink>2023-01-04 22:06:05 +0100
commita5e45e2c0dab95589e5fbaa4fe87d18484fbbe68 (patch)
tree8bd97a745e1113b1035c504ec484e099f878aed0 /models/clip
parentVarious updates (diff)
downloadtextual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.tar.gz
textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.tar.bz2
textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.zip
Update
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py4
-rw-r--r--models/clip/tokenizer.py2
2 files changed, 4 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9c3a56b..1280ebd 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
72 72
73 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 73 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
74 self.temp_token_embedding.weight.data[token_ids] = initializer.to( 74 self.temp_token_embedding.weight.data[token_ids] = initializer.to(
75 dtype=self.temp_token_embedding.weight.dtype) 75 device=self.temp_token_embedding.weight.device,
76 dtype=self.temp_token_embedding.weight.dtype,
77 )
76 78
77 def load_embed(self, input_ids: list[int], filename: Path): 79 def load_embed(self, input_ids: list[int], filename: Path):
78 with safe_open(filename, framework="pt", device="cpu") as file: 80 with safe_open(filename, framework="pt", device="cpu") as file:
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 11a3df0..4e97ab5 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -48,7 +48,7 @@ def shuffle_none(tokens: list[int]):
48 48
49 49
50def shuffle_auto(tokens: list[int]): 50def shuffle_auto(tokens: list[int]):
51 if len(tokens) >= 4: 51 if len(tokens) >= 5:
52 return shuffle_between(tokens) 52 return shuffle_between(tokens)
53 if len(tokens) >= 3: 53 if len(tokens) >= 3:
54 return shuffle_trailing(tokens) 54 return shuffle_trailing(tokens)