diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-04 22:06:05 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-04 22:06:05 +0100 |
| commit | a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68 (patch) | |
| tree | 8bd97a745e1113b1035c504ec484e099f878aed0 /models | |
| parent | Various updates (diff) | |
| download | textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.tar.gz textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.tar.bz2 textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.zip | |
Update
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/embeddings.py | 4 | ||||
| -rw-r--r-- | models/clip/tokenizer.py | 2 |
2 files changed, 4 insertions, 2 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9c3a56b..1280ebd 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 72 | 72 | ||
| 73 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 73 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( | 74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( |
| 75 | dtype=self.temp_token_embedding.weight.dtype) | 75 | device=self.temp_token_embedding.weight.device, |
| 76 | dtype=self.temp_token_embedding.weight.dtype, | ||
| 77 | ) | ||
| 76 | 78 | ||
| 77 | def load_embed(self, input_ids: list[int], filename: Path): | 79 | def load_embed(self, input_ids: list[int], filename: Path): |
| 78 | with safe_open(filename, framework="pt", device="cpu") as file: | 80 | with safe_open(filename, framework="pt", device="cpu") as file: |
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 11a3df0..4e97ab5 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
| @@ -48,7 +48,7 @@ def shuffle_none(tokens: list[int]): | |||
| 48 | 48 | ||
| 49 | 49 | ||
| 50 | def shuffle_auto(tokens: list[int]): | 50 | def shuffle_auto(tokens: list[int]): |
| 51 | if len(tokens) >= 4: | 51 | if len(tokens) >= 5: |
| 52 | return shuffle_between(tokens) | 52 | return shuffle_between(tokens) |
| 53 | if len(tokens) >= 3: | 53 | if len(tokens) >= 3: |
| 54 | return shuffle_trailing(tokens) | 54 | return shuffle_trailing(tokens) |
