author     Volpeon <git@volpeon.ink>    2023-01-05 10:19:38 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-05 10:19:38 +0100
commit     6c64f769043c8212b1a5778e857af691a828798d (patch)
tree       fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /models/clip
parent     Update (diff)
Various cleanups: add_embed() now returns the initializer coverage ratio, and the MultiCLIPTokenizerItem wrapper is removed in favor of returning the new token ids directly.
Diffstat (limited to 'models/clip')
 models/clip/embeddings.py | 5 +++++
 models/clip/tokenizer.py  | 9 ++-------
 2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1280ebd..fb639f1 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -53,6 +53,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+        init_ratio = 1.0
+
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
@@ -63,6 +65,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]
 
         if isinstance(initializer, list):
+            init_ratio = len(initializer) / len(token_ids)
             initializer = (initializer * len(token_ids))[:len(token_ids)]
 
         with torch.no_grad():
@@ -76,6 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
                 dtype=self.temp_token_embedding.weight.dtype,
             )
 
+        return init_ratio
+
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
             self.add_embed(input_ids, file.get_tensor("embed"))
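The practical effect here is that add_embed() now reports how well the initializer covered the requested vectors: init_ratio is len(initializer) / len(token_ids), i.e. 1.0 when every new vector got its own initializer id and less when the initializer list had to be cycled. A minimal sketch of a call site, assuming hypothetical tokenizer and embeddings objects (the actual caller is not part of this diff):

# Hypothetical call site; only add_embed() and its return value come
# from this commit, the surrounding names are assumptions.
ids = tokenizer.add_multi_tokens("<concept>", num_vectors=4)  # e.g. 4 new token ids
init_ids = tokenizer.encode("cat", add_special_tokens=False)  # e.g. 1 initializer id

init_ratio = embeddings.add_embed(ids, initializer=init_ids)

if init_ratio < 1.0:
    # Fewer initializer ids than new vectors: the initializer list was
    # repeated (cycled) to fill all requested vectors.
    print(f"initializer covers only {init_ratio:.0%} of the new vectors")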
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 4e97ab5..034adf9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,11 +55,6 @@ def shuffle_auto(tokens: list[int]):
     return shuffle_all(tokens)
 
 
-class MultiCLIPTokenizerItem(NamedTuple):
-    token: str
-    ids: list[int]
-
-
 class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -96,7 +91,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         self,
         new_tokens: Union[str, list[str]],
         num_vectors: Union[int, list[int]] = 1
-    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
+    ) -> Union[list[int], list[list[int]]]:
         if isinstance(new_tokens, list):
             if isinstance(num_vectors, int):
                 num_vectors = [num_vectors] * len(new_tokens)
@@ -119,7 +114,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         self.token_map[ids[0]] = ids
 
-        return MultiCLIPTokenizerItem(new_tokens, ids)
+        return ids
 
     def expand_id(self, id: int):
         if id in self.token_map:
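With MultiCLIPTokenizerItem gone, add_multi_tokens() returns the new token ids directly: a list[int] for a single string, or a list[list[int]] for a list of strings. A minimal before/after sketch of a caller (names and concrete ids are assumptions, not part of this diff):

# Before: the result had to be unpacked from the NamedTuple.
#   item = tokenizer.add_multi_tokens("<concept>", num_vectors=3)
#   ids = item.ids

# After this commit: the ids come back directly.
ids = tokenizer.add_multi_tokens("<concept>", num_vectors=3)        # -> list[int], e.g. [49408, 49409, 49410]
nested = tokenizer.add_multi_tokens(["<a>", "<b>"], num_vectors=2)  # -> list[list[int]]

Dropping the wrapper also removes the redundant token field: the caller already knows which token string it passed in, so only the ids carried new information.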