From 5ff238bd5bb422d855d5f0b8c81402e74a9da3cc Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 7 Jan 2023 08:55:31 +0100
Subject: Update

---
 models/clip/embeddings.py | 5 -----
 train_dreambooth.py       | 6 +++---
 train_ti.py               | 6 +++---
 3 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 384c795..9d8f770 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -53,8 +53,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
-        init_ratio = 1.0
-
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
@@ -65,7 +63,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]
 
         if isinstance(initializer, list):
-            init_ratio = len(initializer) / len(token_ids)
             initializer = (initializer * len(token_ids))[:len(token_ids)]
 
         with torch.no_grad():
@@ -79,8 +76,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
                 dtype=self.temp_token_embedding.weight.dtype,
             )
 
-        return init_ratio
-
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
             self.add_embed(input_ids, file.get_tensor("embed"))
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c355ea1..e8256be 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -624,10 +624,10 @@ def main():
         new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
         embeddings.resize(len(tokenizer))
 
-        init_ratios = [
+        for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
             embeddings.add_embed(new_id, init_ids)
-            for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
-        ]
+
+        init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]
 
         print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
     else:
diff --git a/train_ti.py b/train_ti.py
index 1b8c597..0ffc9e6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -607,10 +607,10 @@ def main():
         new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
         embeddings.resize(len(tokenizer))
 
-        init_ratios = [
+        for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
             embeddings.add_embed(new_id, init_ids)
-            for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
-        ]
+
+        init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]
 
         print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
 
-- 
cgit v1.2.3-54-g00ecf
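
For reference, a minimal standalone sketch (not part of the patch) of the new
call-site behavior: add_embed() no longer returns an init ratio, so the ratio
strings are now built separately from the same (new_id, init_ids) pairs. The
sample values below are hypothetical stand-ins for the trainer's real
new_ids / initializer_token_ids.

    # Two placeholders: one trained with 2 vectors, one with 1.
    new_ids = [[49408, 49409], [49410]]
    # Initializer token ids per placeholder.
    initializer_token_ids = [[320], [1125, 2368]]

    # Same list comprehension as in the patch: "initializers / vectors".
    init_ratios = [
        f"{len(init_ids)} / {len(new_id)}"
        for new_id, init_ids in zip(new_ids, initializer_token_ids)
    ]

    print(init_ratios)  # ['1 / 2', '2 / 1']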