Diffstat (limited to 'models')
 -rw-r--r--  models/clip/embeddings.py  5
 -rw-r--r--  models/clip/tokenizer.py   9
 2 files changed, 7 insertions, 7 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1280ebd..fb639f1 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -53,6 +53,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+        init_ratio = 1.0
+
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
@@ -63,6 +65,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]
 
         if isinstance(initializer, list):
+            init_ratio = len(initializer) / len(token_ids)
             initializer = (initializer * len(token_ids))[:len(token_ids)]
 
         with torch.no_grad():
@@ -76,6 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             dtype=self.temp_token_embedding.weight.dtype,
         )
 
+        return init_ratio
+
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
             self.add_embed(input_ids, file.get_tensor("embed"))
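Note on the hunks above: add_embed() now reports how much of the new token ids a list initializer actually covered, as init_ratio = len(initializer) / len(token_ids); the ratio stays 1.0 when no list initializer is given. A minimal usage sketch, assuming a ManagedCLIPTextEmbeddings instance named embeddings and a list of freshly added token ids named new_ids; none of this caller code is part of the commit:

# Hypothetical caller sketch for the new add_embed() return value.
# `embeddings` and `new_ids` are placeholders, not from this diff.
init_ratio = embeddings.add_embed(new_ids, initializer=[42])  # one initializer id for several tokens
if init_ratio < 1.0:
    # fewer initializer entries than token ids: the initializer list was tiled to fit
    print(f"initializer covered only {init_ratio:.0%} of the new vectors")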
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 4e97ab5..034adf9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,11 +55,6 @@ def shuffle_auto(tokens: list[int]):
     return shuffle_all(tokens)
 
 
-class MultiCLIPTokenizerItem(NamedTuple):
-    token: str
-    ids: list[int]
-
-
 class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -96,7 +91,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         self,
         new_tokens: Union[str, list[str]],
         num_vectors: Union[int, list[int]] = 1
-    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
+    ) -> Union[list[int], list[list[int]]]:
         if isinstance(new_tokens, list):
             if isinstance(num_vectors, int):
                 num_vectors = [num_vectors] * len(new_tokens)
@@ -119,7 +114,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         self.token_map[ids[0]] = ids
 
-        return MultiCLIPTokenizerItem(new_tokens, ids)
+        return ids
 
     def expand_id(self, id: int):
         if id in self.token_map:
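Note on the hunks above: the MultiCLIPTokenizerItem NamedTuple is dropped, so add_multi_tokens() now returns the new token ids directly, list[int] for a single token string or list[list[int]] when a list of tokens is passed. A minimal sketch of how a caller might combine this with the add_embed() change from the first file, assuming existing tokenizer and embeddings objects (placeholders, not part of this commit):

# Hypothetical glue code tying the two API changes together.
# `tokenizer` (MultiCLIPTokenizer) and `embeddings` (ManagedCLIPTextEmbeddings)
# are assumed to exist already.
ids = tokenizer.add_multi_tokens("<new-concept>", num_vectors=3)  # now a plain list[int]
init_ratio = embeddings.add_embed(ids)  # now returns the initializer coverage ratio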
