From dc463a6b8ef120b7a0643569b66f9109ed38c652 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 31 Dec 2022 14:07:44 +0100
Subject: Simplified multi-vector embedding code

---
 common.py                |  3 +--
 models/clip/tokenizer.py | 23 +++++++++++------------
 train_ti.py              |  5 ++---
 3 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/common.py b/common.py
index e8d3ac1..1e7f4b9 100644
--- a/common.py
+++ b/common.py
@@ -32,7 +32,6 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
             embed = file.get_tensor("embed")

             added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
-            embeddings.add_embed(added.placeholder_id)
-            embeddings.add_embed(added.multi_ids, embed)
+            embeddings.add_embed(added.ids, embed)

     return tokens
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 78871db..7e08287 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -8,8 +8,8 @@ from transformers import CLIPTokenizer

 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
-    placeholder_id: int
-    multi_ids: list[int]
+    meta_id: int
+    ids: list[int]


 class MultiCLIPTokenizer(CLIPTokenizer):
@@ -30,20 +30,19 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         if isinstance(num_vectors, list):
             raise ValueError("Expected num_vectors to be int for single token")

-        super().add_tokens(new_tokens)
+        if num_vectors < 1:
+            raise ValueError("Expected num_vectors to be >= 1")

-        if num_vectors == 1:
-            multi_token = [new_tokens]
-        else:
-            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
-            super().add_tokens(multi_token)
+        multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]

-        meta_id = super().convert_tokens_to_ids(new_tokens)
-        multi_ids = super().convert_tokens_to_ids(multi_token)
+        super().add_tokens(multi_token)

-        self.token_map[meta_id] = multi_ids
+        ids = super().convert_tokens_to_ids(multi_token)
+        meta_id = ids[0]

-        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
+        self.token_map[meta_id] = ids
+
+        return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)

     def encode(self, *args, vector_shuffle=True, **kwargs):
         ids = super().encode(*args, **kwargs)
diff --git a/train_ti.py b/train_ti.py
index 3a5cfed..3776eb2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -438,7 +438,7 @@ class Checkpointer(CheckpointerBase):

         for new_token in self.new_tokens:
             text_encoder.text_model.embeddings.save_embed(
-                new_token.multi_ids,
+                new_token.ids,
                 checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
             )

@@ -537,8 +537,7 @@ def main():
         new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)

         for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-            embeddings.add_embed(new_token.placeholder_id)
-            embeddings.add_embed(new_token.multi_ids, init_ids)
+            embeddings.add_embed(new_token.ids, init_ids)

         print(f"Added {len(new_tokens)} new tokens.")
--
cgit v1.2.3-54-g00ecf
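
Note (not part of the commit): a minimal standalone sketch of the id bookkeeping after this change. MiniTokenizer is hypothetical, it fakes the vocabulary instead of doing real CLIP tokenization, but its add_multi_tokens mirrors the simplified control flow above: one token list covers both the single- and multi-vector case, and the first id doubles as the meta id that token_map expands.

    # Sketch of the simplified multi-vector token expansion.
    # MiniTokenizer stands in for MultiCLIPTokenizer; only the
    # id bookkeeping is modeled, not actual tokenization.

    class MiniTokenizer:
        def __init__(self):
            self.vocab: dict[str, int] = {}
            self.token_map: dict[int, list[int]] = {}

        def add_tokens(self, tokens: list[str]) -> None:
            # Assign sequential ids to previously unseen tokens.
            for token in tokens:
                self.vocab.setdefault(token, len(self.vocab))

        def add_multi_tokens(self, new_token: str, num_vectors: int = 1) -> list[int]:
            if num_vectors < 1:
                raise ValueError("Expected num_vectors to be >= 1")

            # One expression covers both cases: for num_vectors == 1 the
            # comprehension is empty, so only the bare token is added.
            multi_token = [new_token] + [f"{new_token}_{i}" for i in range(1, num_vectors)]
            self.add_tokens(multi_token)

            ids = [self.vocab[t] for t in multi_token]
            self.token_map[ids[0]] = ids  # the meta id expands to all vector ids
            return ids

    tok = MiniTokenizer()
    ids = tok.add_multi_tokens("<token>", num_vectors=3)
    print(ids)                    # [0, 1, 2]
    print(tok.token_map[ids[0]])  # [0, 1, 2]

Callers now receive a single ids list and pass it straight to embeddings.add_embed, instead of juggling a separate placeholder_id and multi_ids as before.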