From 6c64f769043c8212b1a5778e857af691a828798d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 10:19:38 +0100 Subject: Various cleanups --- models/clip/embeddings.py | 5 +++++ models/clip/tokenizer.py | 9 ++------- 2 files changed, 7 insertions(+), 7 deletions(-) (limited to 'models/clip') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 1280ebd..fb639f1 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -53,6 +53,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): + init_ratio = 1.0 + if isinstance(token_ids, int): token_ids = [token_ids] @@ -63,6 +65,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): initializer = [initializer] if isinstance(initializer, list): + init_ratio = len(initializer) / len(token_ids) initializer = (initializer * len(token_ids))[:len(token_ids)] with torch.no_grad(): @@ -76,6 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): dtype=self.temp_token_embedding.weight.dtype, ) + return init_ratio + def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: self.add_embed(input_ids, file.get_tensor("embed")) diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 4e97ab5..034adf9 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -55,11 +55,6 @@ def shuffle_auto(tokens: list[int]): return shuffle_all(tokens) -class MultiCLIPTokenizerItem(NamedTuple): - token: str - ids: list[int] - - class MultiCLIPTokenizer(CLIPTokenizer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -96,7 +91,7 @@ class MultiCLIPTokenizer(CLIPTokenizer): self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1 - ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]: + ) -> Union[list[int], list[list[int]]]: if isinstance(new_tokens, list): if isinstance(num_vectors, int): num_vectors = [num_vectors] * len(new_tokens) @@ -119,7 +114,7 @@ class MultiCLIPTokenizer(CLIPTokenizer): self.token_map[ids[0]] = ids - return MultiCLIPTokenizerItem(new_tokens, ids) + return ids def expand_id(self, id: int): if id in self.token_map: -- cgit v1.2.3-54-g00ecf