summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py5
-rw-r--r--train_dreambooth.py6
-rw-r--r--train_ti.py6
3 files changed, 6 insertions, 11 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 384c795..9d8f770 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -53,8 +53,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
54 54
55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
56 init_ratio = 1.0
57
58 if isinstance(token_ids, int): 56 if isinstance(token_ids, int):
59 token_ids = [token_ids] 57 token_ids = [token_ids]
60 58
@@ -65,7 +63,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
65 initializer = [initializer] 63 initializer = [initializer]
66 64
67 if isinstance(initializer, list): 65 if isinstance(initializer, list):
68 init_ratio = len(initializer) / len(token_ids)
69 initializer = (initializer * len(token_ids))[:len(token_ids)] 66 initializer = (initializer * len(token_ids))[:len(token_ids)]
70 67
71 with torch.no_grad(): 68 with torch.no_grad():
@@ -79,8 +76,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
79 dtype=self.temp_token_embedding.weight.dtype, 76 dtype=self.temp_token_embedding.weight.dtype,
80 ) 77 )
81 78
82 return init_ratio
83
84 def load_embed(self, input_ids: list[int], filename: Path): 79 def load_embed(self, input_ids: list[int], filename: Path):
85 with safe_open(filename, framework="pt", device="cpu") as file: 80 with safe_open(filename, framework="pt", device="cpu") as file:
86 self.add_embed(input_ids, file.get_tensor("embed")) 81 self.add_embed(input_ids, file.get_tensor("embed"))
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c355ea1..e8256be 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -624,10 +624,10 @@ def main():
624 new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) 624 new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
625 embeddings.resize(len(tokenizer)) 625 embeddings.resize(len(tokenizer))
626 626
627 init_ratios = [ 627 for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
628 embeddings.add_embed(new_id, init_ids) 628 embeddings.add_embed(new_id, init_ids)
629 for (new_id, init_ids) in zip(new_ids, initializer_token_ids) 629
630 ] 630 init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]
631 631
632 print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") 632 print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
633 else: 633 else:
diff --git a/train_ti.py b/train_ti.py
index 1b8c597..0ffc9e6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -607,10 +607,10 @@ def main():
607 new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) 607 new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
608 embeddings.resize(len(tokenizer)) 608 embeddings.resize(len(tokenizer))
609 609
610 init_ratios = [ 610 for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
611 embeddings.add_embed(new_id, init_ids) 611 embeddings.add_embed(new_id, init_ids)
612 for (new_id, init_ids) in zip(new_ids, initializer_token_ids) 612
613 ] 613 init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]
614 614
615 print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") 615 print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
616 616