author    Volpeon <git@volpeon.ink>  2023-06-24 21:00:29 +0200
committer Volpeon <git@volpeon.ink>  2023-06-24 21:00:29 +0200
commit    12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 (patch)
tree      b0fcf8ad1d26c40d784ddc154622f6d01ecac082
parent    New loss scaling (diff)
Update
-rw-r--r--  models/clip/embeddings.py         4
-rw-r--r--  models/sparse.py                 11
-rw-r--r--  train_dreambooth.py               4
-rw-r--r--  train_lora.py                     6
-rw-r--r--  train_ti.py                       6
-rw-r--r--  training/strategy/dreambooth.py   3
6 files changed, 20 insertions, 14 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8c3c6d4..afb7430 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
     def save_embed(self, input_ids: list[int], filename: Path):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
-    def persist(self):
-        self.token_embedding.persist()
+    def persist(self, clear=False):
+        self.token_embedding.persist(clear)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
diff --git a/models/sparse.py b/models/sparse.py
index e5897c9..55c9837 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding):
 
         return weights
 
-    def persist(self):
+    def persist(self, clear=False):
         self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-        self.trainable_ids[:] = -1
-        self.trainable = nn.ParameterList()
+
+        if clear:
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+        else:
+            for param in self.trainable:
+                param.zero_()
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
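
The new clear flag controls what happens to the trainable deltas after they are folded into the base weight: persist(True) keeps the previous behaviour and drops the trainable parameters entirely, while persist(False) zeroes them in place so the same token ids keep training from the merged state. A minimal standalone sketch of that pattern follows; ToyDeltaEmbedding is a hypothetical stand-in for illustration, not the repository's SparseEmbedding API.

import torch
import torch.nn as nn


class ToyDeltaEmbedding(nn.Module):
    """Hypothetical sketch: frozen base embedding plus per-id trainable deltas."""

    def __init__(self, num_embeddings: int, dim: int, trainable_ids: list[int]):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_embeddings, dim), requires_grad=False)
        self.trainable_ids = list(trainable_ids)
        # One trainable delta per id, added on top of the frozen base weight.
        self.trainable = nn.ParameterList(
            [nn.Parameter(torch.zeros(dim)) for _ in trainable_ids]
        )

    def forward(self, ids: torch.LongTensor) -> torch.Tensor:
        out = self.weight[ids].clone()
        for slot, token_id in enumerate(self.trainable_ids):
            out[ids == token_id] += self.trainable[slot]
        return out

    def persist(self, clear=False):
        # Fold the learned deltas into the base weight.
        for slot, token_id in enumerate(self.trainable_ids):
            self.weight.data[token_id] += self.trainable[slot].data
        if clear:
            # Previous behaviour: nothing stays trainable afterwards.
            self.trainable_ids = []
            self.trainable = nn.ParameterList()
        else:
            # New behaviour: keep the same ids trainable, starting again from zero deltas.
            for param in self.trainable:
                param.data.zero_()
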
diff --git a/train_dreambooth.py b/train_dreambooth.py
index beb65fc..929310b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -661,7 +661,7 @@ def main():
             placeholder_tokens=alias_placeholder_tokens,
             initializer_tokens=alias_initializer_tokens,
         )
-        embeddings.persist()
+        embeddings.persist(True)
         print(
             f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
         )
@@ -682,7 +682,7 @@ def main():
             f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
         )
 
-    embeddings.persist()
+    embeddings.persist(True)
 
     if len(args.placeholder_tokens) != 0:
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
diff --git a/train_lora.py b/train_lora.py
index 2a43252..eeac81f 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -777,7 +777,7 @@ def main():
             placeholder_tokens=alias_placeholder_tokens,
             initializer_tokens=alias_initializer_tokens,
         )
-        embeddings.persist()
+        embeddings.persist(True)
         print(
             f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
         )
@@ -806,7 +806,7 @@ def main():
         if args.train_dir_embeddings:
             print("Training embeddings from embeddings dir")
         else:
-            embeddings.persist()
+            embeddings.persist(True)
 
     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
         embeddings = ensure_embeddings()
@@ -1117,7 +1117,7 @@ def main():
             no_val=True,
         )
 
-        embeddings.persist()
+        embeddings.persist(True)
 
         # LORA
         # --------------------------------------------------------------------------------
diff --git a/train_ti.py b/train_ti.py
index 89f4113..1d0cb6f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -691,7 +691,7 @@ def main():
             placeholder_tokens=alias_placeholder_tokens,
             initializer_tokens=alias_initializer_tokens,
         )
-        embeddings.persist()
+        embeddings.persist(True)
         print(
             f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
         )
@@ -712,7 +712,7 @@ def main():
         args.placeholder_tokens = added_tokens
         print("Training embeddings from embeddings dir")
     else:
-        embeddings.persist()
+        embeddings.persist(True)
 
     if args.scale_lr:
         args.learning_rate = (
@@ -1067,7 +1067,7 @@ def main():
         args.train_data_template,
     ):
         run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
-        embeddings.persist()
+        embeddings.persist(True)
 
 
 if __name__ == "__main__":
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bd853e2..3d1abf7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -98,7 +98,6 @@ def dreambooth_strategy_callbacks(
 
         if cycle < train_text_encoder_cycles:
             text_encoder.train()
-            tokenizer.train()
 
         yield
 
@@ -155,6 +154,8 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        text_encoder_.text_model.embeddings.persist(False)
+
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
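
Taken together with the training-script hunks above, the flag separates two kinds of call sites: the scripts use persist(True) wherever embeddings should be merged and then frozen for the rest of the run (aliases, embeddings-dir imports, finished TI/PTI vectors), while the DreamBooth strategy now calls persist(False) right before building the export pipeline, so the saved weights contain the learned vectors and training can resume on the same ids afterwards. A hedged usage sketch, reusing the hypothetical ToyDeltaEmbedding from the earlier sketch:

# ToyDeltaEmbedding is the hypothetical stand-in defined above, not the repo's API.
imported = ToyDeltaEmbedding(num_embeddings=10, dim=4, trainable_ids=[3])
imported.persist(True)    # alias / embeddings-dir tokens: merge once, freeze for good

live = ToyDeltaEmbedding(num_embeddings=10, dim=4, trainable_ids=[7])
# ... optimizer steps update live.trainable here ...
live.persist(False)       # checkpoint/export: merge current deltas, keep ids trainable
# training continues; a later persist(False) merges only the newly learned deltas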