-rw-r--r--  models/clip/embeddings.py        |  4
-rw-r--r--  models/sparse.py                 | 11
-rw-r--r--  train_dreambooth.py              |  4
-rw-r--r--  train_lora.py                    |  6
-rw-r--r--  train_ti.py                      |  6
-rw-r--r--  training/strategy/dreambooth.py  |  3
6 files changed, 20 insertions, 14 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8c3c6d4..afb7430 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
     def save_embed(self, input_ids: list[int], filename: Path):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
-    def persist(self):
-        self.token_embedding.persist()
+    def persist(self, clear=False):
+        self.token_embedding.persist(clear)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
diff --git a/models/sparse.py b/models/sparse.py
index e5897c9..55c9837 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding):
 
         return weights
 
-    def persist(self):
+    def persist(self, clear=False):
         self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-        self.trainable_ids[:] = -1
-        self.trainable = nn.ParameterList()
+
+        if clear:
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+        else:
+            for param in self.trainable:
+                param.zero_()
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
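
Note on the new persist() semantics (a reading of the hunk above, not part of the commit message): the trained per-token deltas are always folded into the base embedding matrix; clear=True additionally drops the trainable overlay, which matches the old behaviour, while clear=False keeps the overlay registered but zeroes it so training can continue on top of the merged weights. A minimal usage sketch, assuming an embeddings object of the managed type from models/clip/embeddings.py:

    # Merge loaded/aliased embeddings for good and stop tracking them as trainable
    # (the old behaviour; what the train_* scripts below now request explicitly).
    embeddings.persist(True)

    # Merge current training progress into the base weights but keep the trainable
    # parameters alive (zeroed), e.g. before exporting a pipeline mid-training.
    embeddings.persist(False)
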
diff --git a/train_dreambooth.py b/train_dreambooth.py
index beb65fc..929310b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -661,7 +661,7 @@ def main():
         placeholder_tokens=alias_placeholder_tokens,
         initializer_tokens=alias_initializer_tokens,
     )
-    embeddings.persist()
+    embeddings.persist(True)
     print(
         f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
     )
@@ -682,7 +682,7 @@ def main():
         f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
     )
 
-    embeddings.persist()
+    embeddings.persist(True)
 
     if len(args.placeholder_tokens) != 0:
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
diff --git a/train_lora.py b/train_lora.py
index 2a43252..eeac81f 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -777,7 +777,7 @@ def main():
         placeholder_tokens=alias_placeholder_tokens,
         initializer_tokens=alias_initializer_tokens,
     )
-    embeddings.persist()
+    embeddings.persist(True)
     print(
         f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
     )
@@ -806,7 +806,7 @@ def main():
     if args.train_dir_embeddings:
         print("Training embeddings from embeddings dir")
     else:
-        embeddings.persist()
+        embeddings.persist(True)
 
     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
         embeddings = ensure_embeddings()
@@ -1117,7 +1117,7 @@ def main():
             no_val=True,
         )
 
-        embeddings.persist()
+        embeddings.persist(True)
 
         # LORA
         # --------------------------------------------------------------------------------
diff --git a/train_ti.py b/train_ti.py
index 89f4113..1d0cb6f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -691,7 +691,7 @@ def main():
         placeholder_tokens=alias_placeholder_tokens,
         initializer_tokens=alias_initializer_tokens,
     )
-    embeddings.persist()
+    embeddings.persist(True)
     print(
         f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
     )
@@ -712,7 +712,7 @@ def main():
         args.placeholder_tokens = added_tokens
         print("Training embeddings from embeddings dir")
     else:
        embeddings.persist()
 
     if args.scale_lr:
         args.learning_rate = (
@@ -1067,7 +1067,7 @@ def main():
            args.train_data_template,
        ):
            run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
            embeddings.persist()
 
 
 if __name__ == "__main__":
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bd853e2..3d1abf7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -98,7 +98,6 @@ def dreambooth_strategy_callbacks(
 
         if cycle < train_text_encoder_cycles:
             text_encoder.train()
-            tokenizer.train()
 
         yield
 
@@ -155,6 +154,8 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        text_encoder_.text_model.embeddings.persist(False)
+
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
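
A note on the strategy change (an inference from the diff, not stated in the commit): persist(False) runs on the unwrapped text encoder just before the sampling/checkpoint pipeline is assembled, so the exported model carries the merged token embeddings while the zeroed trainable overlay stays registered for the following training cycles; clearing it here would stop those embeddings from being trained further. A rough sketch of the pattern, using only names that appear in the hunk above:

    # Export-time pattern (sketch): unwrap, merge embeddings non-destructively, then
    # build the pipeline from the unwrapped text encoder.
    text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
    text_encoder_.text_model.embeddings.persist(False)  # merge deltas, keep (zeroed) trainable slots
    # ... pipeline = VlpnStableDiffusion(text_encoder=text_encoder_, ...) as above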