From adc52fb8821a496bc8d78235bf10466b39df03e0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 1 Jan 2023 19:19:52 +0100 Subject: Updates --- models/clip/embeddings.py | 11 +++++++ models/clip/tokenizer.py | 76 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 71 insertions(+), 16 deletions(-) (limited to 'models/clip') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index f90e7c2..8602142 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) text_encoder.text_model.embeddings = text_embeddings return text_embeddings + + +def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings: + text_encoder.text_model.embeddings.make_permanent() + + text_embeddings = CLIPTextEmbeddings(text_encoder.config) + text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding + text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding + text_encoder.text_model.embeddings = text_embeddings + + return text_embeddings diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 37d69a9..ed9774e 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -1,11 +1,54 @@ import copy -from typing import NamedTuple, Union +from typing import NamedTuple, Union, Literal import numpy as np from transformers import CLIPTokenizer +def shuffle_all(tokens: list[int]): + if len(tokens) >= 2: + tokens = copy.copy(tokens) + np.random.shuffle(tokens) + return tokens + + +def shuffle_leading(tokens: list[int]): + if len(tokens) >= 3: + subtokens = tokens[:-1] + np.random.shuffle(subtokens) + tokens = subtokens + tokens[-1:] + return tokens + + +def shuffle_trailing(tokens: list[int]): + if len(tokens) >= 3: + subtokens = tokens[1:] + np.random.shuffle(subtokens) + tokens = tokens[:1] + subtokens + return tokens + + +def shuffle_between(tokens: list[int]): + if len(tokens) >= 4: + subtokens = tokens[1:-1] + np.random.shuffle(subtokens) + tokens = tokens[:1] + subtokens + tokens[-1:] + return tokens + + +def shuffle_none(tokens: list[int]): + return tokens + + +def shuffle_auto(tokens: list[int]): + if len(tokens) >= 4: + return shuffle_between(tokens) + if len(tokens) >= 3: + return shuffle_trailing(tokens) + return shuffle_all(tokens) + + class MultiCLIPTokenizerItem(NamedTuple): token: str ids: list[int] @@ -15,10 +58,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.token_map: dict[int, list[int]] = {} - self.vector_shuffle = False - - def set_use_vector_shuffle(self, enable: bool): - self.vector_shuffle = enable + self.vector_shuffle = shuffle_none + + def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): + if algorithm == "leading": + self.vector_shuffle = shuffle_leading + elif algorithm == "trailing": + self.vector_shuffle = shuffle_trailing + elif algorithm == "between": + self.vector_shuffle = shuffle_between + elif algorithm == "auto": + self.vector_shuffle = shuffle_auto + elif algorithm == True or algorithm == "all": + self.vector_shuffle = shuffle_all + else: + self.vector_shuffle = shuffle_none def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: if isinstance(new_tokens, list): @@ -47,17 +101,7 @@ class MultiCLIPTokenizer(CLIPTokenizer): return MultiCLIPTokenizerItem(new_tokens, ids) def expand_id(self, id: int): - if id in self.token_map: - tokens = self.token_map[id] - - if self.vector_shuffle and len(tokens) > 2: - subtokens = tokens[1:-1] - np.random.shuffle(subtokens) - tokens = tokens[:1] + subtokens + tokens[-1:] - - return tokens - else: - return [id] + return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id] def expand_ids(self, ids: list[int]): return [ -- cgit v1.2.3-70-g09d2