 models/clip/embeddings.py       |  4 ++--
 models/sparse.py                | 11 ++++++++---
 train_dreambooth.py             |  4 ++--
 train_lora.py                   |  6 +++---
 train_ti.py                     |  6 +++---
 training/strategy/dreambooth.py |  3 ++-
 6 files changed, 20 insertions(+), 14 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8c3c6d4..afb7430 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
     def save_embed(self, input_ids: list[int], filename: Path):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
-    def persist(self):
-        self.token_embedding.persist()
+    def persist(self, clear=False):
+        self.token_embedding.persist(clear)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
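The hunk above only threads the new flag through the wrapper; callers of
ManagedCLIPTextEmbeddings.persist() now choose between two modes. A hedged
summary of the call-site semantics (the flag is passed positionally in the
hunks below):

    embeddings.persist(True)   # fold learned vectors into the base weights, drop the trainable slots
    embeddings.persist(False)  # fold into the base weights, keep the slots trainable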
diff --git a/models/sparse.py b/models/sparse.py
index e5897c9..55c9837 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding):
 
         return weights
 
-    def persist(self):
+    def persist(self, clear=False):
         self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-        self.trainable_ids[:] = -1
-        self.trainable = nn.ParameterList()
+
+        if clear:
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+        else:
+            for param in self.trainable:
+                param.zero_()
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
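The else branch zeroes each trainable parameter only after its value has been
folded into self.weight, so a lookup returns the same embedding while the
slots stay available for further training. A minimal standalone sketch of
that fold-and-zero invariant, assuming the trainable parameters act as
additive deltas (which is what the += fold above implies); the tensors are
illustrative, not the repository's API:

    import torch

    weight = torch.zeros(4)        # persisted base weights
    delta = torch.full((4,), 0.5)  # trainable per-token offset

    before = weight + delta       # effective embedding before persisting
    weight += delta               # fold the learned offset into the base
    delta.zero_()                 # keep the parameter, reset its value
    after = weight + delta        # unchanged: the fold is lossless

    assert torch.equal(before, after)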
diff --git a/train_dreambooth.py b/train_dreambooth.py
index beb65fc..929310b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -661,7 +661,7 @@ def main():
         placeholder_tokens=alias_placeholder_tokens,
         initializer_tokens=alias_initializer_tokens,
     )
-    embeddings.persist()
+    embeddings.persist(True)
     print(
         f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
     )
@@ -682,7 +682,7 @@ def main():
         f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
     )
 
-    embeddings.persist()
+    embeddings.persist(True)
 
     if len(args.placeholder_tokens) != 0:
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
diff --git a/train_lora.py b/train_lora.py
index 2a43252..eeac81f 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -777,7 +777,7 @@ def main():
         placeholder_tokens=alias_placeholder_tokens,
         initializer_tokens=alias_initializer_tokens,
     )
-    embeddings.persist()
+    embeddings.persist(True)
     print(
         f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
     )
@@ -806,7 +806,7 @@ def main():
     if args.train_dir_embeddings:
         print("Training embeddings from embeddings dir")
     else:
-        embeddings.persist()
+        embeddings.persist(True)
 
     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
         embeddings = ensure_embeddings()
@@ -1117,7 +1117,7 @@ def main():
         no_val=True,
     )
 
-    embeddings.persist()
+    embeddings.persist(True)
 
     # LORA
     # --------------------------------------------------------------------------------
diff --git a/train_ti.py b/train_ti.py
index 89f4113..1d0cb6f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -691,7 +691,7 @@ def main():
         placeholder_tokens=alias_placeholder_tokens,
         initializer_tokens=alias_initializer_tokens,
     )
-    embeddings.persist()
+    embeddings.persist(True)
     print(
         f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
     )
@@ -712,7 +712,7 @@ def main():
         args.placeholder_tokens = added_tokens
         print("Training embeddings from embeddings dir")
     else:
-        embeddings.persist()
+        embeddings.persist(True)
 
     if args.scale_lr:
         args.learning_rate = (
@@ -1067,7 +1067,7 @@ def main():
         args.train_data_template,
     ):
         run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
-        embeddings.persist()
+        embeddings.persist(True)
 
 
 if __name__ == "__main__":
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bd853e2..3d1abf7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -98,7 +98,6 @@ def dreambooth_strategy_callbacks(
 
     if cycle < train_text_encoder_cycles:
         text_encoder.train()
-        tokenizer.train()
 
     yield
 
@@ -155,6 +154,8 @@ def dreambooth_strategy_callbacks(
     unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
     text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+    text_encoder_.text_model.embeddings.persist(False)
+
    with ema_context():
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder_,
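The new persist(False) call runs after the models are unwrapped and before
they are handed to the pipeline, presumably so the checkpointed text encoder
already contains the not-yet-persisted trainable deltas while later training
cycles can keep updating the same tokens. A hedged restatement of that
ordering assumption:

    # fold the current trainable deltas into the weights that will be wrapped
    # into the pipeline, but keep the slots alive (zeroed) for the next cycle
    text_encoder_.text_model.embeddings.persist(False)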
