diff options
-rw-r--r-- | models/clip/embeddings.py | 5 | ||||
-rw-r--r-- | train_dreambooth.py | 6 | ||||
-rw-r--r-- | train_ti.py | 6 |
3 files changed, 6 insertions, 11 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 384c795..9d8f770 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -53,8 +53,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
54 | 54 | ||
55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): |
56 | init_ratio = 1.0 | ||
57 | |||
58 | if isinstance(token_ids, int): | 56 | if isinstance(token_ids, int): |
59 | token_ids = [token_ids] | 57 | token_ids = [token_ids] |
60 | 58 | ||
@@ -65,7 +63,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
65 | initializer = [initializer] | 63 | initializer = [initializer] |
66 | 64 | ||
67 | if isinstance(initializer, list): | 65 | if isinstance(initializer, list): |
68 | init_ratio = len(initializer) / len(token_ids) | ||
69 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 66 | initializer = (initializer * len(token_ids))[:len(token_ids)] |
70 | 67 | ||
71 | with torch.no_grad(): | 68 | with torch.no_grad(): |
@@ -79,8 +76,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
79 | dtype=self.temp_token_embedding.weight.dtype, | 76 | dtype=self.temp_token_embedding.weight.dtype, |
80 | ) | 77 | ) |
81 | 78 | ||
82 | return init_ratio | ||
83 | |||
84 | def load_embed(self, input_ids: list[int], filename: Path): | 79 | def load_embed(self, input_ids: list[int], filename: Path): |
85 | with safe_open(filename, framework="pt", device="cpu") as file: | 80 | with safe_open(filename, framework="pt", device="cpu") as file: |
86 | self.add_embed(input_ids, file.get_tensor("embed")) | 81 | self.add_embed(input_ids, file.get_tensor("embed")) |
diff --git a/train_dreambooth.py b/train_dreambooth.py index c355ea1..e8256be 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -624,10 +624,10 @@ def main(): | |||
624 | new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) | 624 | new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) |
625 | embeddings.resize(len(tokenizer)) | 625 | embeddings.resize(len(tokenizer)) |
626 | 626 | ||
627 | init_ratios = [ | 627 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids): |
628 | embeddings.add_embed(new_id, init_ids) | 628 | embeddings.add_embed(new_id, init_ids) |
629 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids) | 629 | |
630 | ] | 630 | init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)] |
631 | 631 | ||
632 | print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") | 632 | print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") |
633 | else: | 633 | else: |
diff --git a/train_ti.py b/train_ti.py index 1b8c597..0ffc9e6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -607,10 +607,10 @@ def main(): | |||
607 | new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) | 607 | new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) |
608 | embeddings.resize(len(tokenizer)) | 608 | embeddings.resize(len(tokenizer)) |
609 | 609 | ||
610 | init_ratios = [ | 610 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids): |
611 | embeddings.add_embed(new_id, init_ids) | 611 | embeddings.add_embed(new_id, init_ids) |
612 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids) | 612 | |
613 | ] | 613 | init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)] |
614 | 614 | ||
615 | print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") | 615 | print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") |
616 | 616 | ||