 models/clip/embeddings.py |  5 ++++-
 train_lora.py             | 21 ++++++++++++++-------
 train_ti.py               | 17 ++++++++++++++---
 training/strategy/ti.py   |  2 +-
 4 files changed, 33 insertions(+), 12 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index dc4708a..9be8256 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -97,7 +97,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device)
+        input_ids = torch.arange(
+            self.token_embedding.num_embeddings,
+            device=self.token_override_embedding.mapping.device
+        )
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
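
Note (not part of the patch): the old line read the device from the override-embedding container itself; unless that container defines its own .device property (a standard nn.Module does not), persist() would raise AttributeError, which is presumably what taking the device from the mapping buffer fixes. A minimal sketch of that failure mode, assuming token_override_embedding is a module with a registered "mapping" buffer (the OverrideEmbedding class below is a hypothetical stand-in, not the real one):

    import torch
    import torch.nn as nn

    class OverrideEmbedding(nn.Module):  # hypothetical stand-in for the real container
        def __init__(self, num_embeddings: int):
            super().__init__()
            # maps token ids to override slots; this buffer carries the module's device
            self.register_buffer("mapping", torch.full((num_embeddings,), -1, dtype=torch.long))

    emb = OverrideEmbedding(10)
    print(emb.mapping.device)   # works: device of the registered buffer
    # print(emb.device)         # AttributeError: plain nn.Module has no 'device' attribute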
diff --git a/train_lora.py b/train_lora.py
index 54c9e7a..e81742a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -81,6 +81,12 @@ def parse_args():
81 help="The name of the current project.", 81 help="The name of the current project.",
82 ) 82 )
83 parser.add_argument( 83 parser.add_argument(
84 "--auto_cycles",
85 type=int,
86 default=1,
87 help="How many cycles to run automatically."
88 )
89 parser.add_argument(
84 "--placeholder_tokens", 90 "--placeholder_tokens",
85 type=str, 91 type=str,
86 nargs='*', 92 nargs='*',
@@ -933,10 +939,15 @@ def main():
         train_epochs=num_train_epochs,
     )
 
-    continue_training = True
-    training_iter = 1
+    training_iter = 0
+
+    while True:
+        training_iter += 1
+        if training_iter > args.auto_cycles:
+            response = input("Run another cycle? [y/n] ")
+            if response.lower().strip() == "n":
+                break
 
-    while continue_training:
         print("")
         print(f"============ LoRA cycle {training_iter} ============")
         print("")
@@ -961,10 +972,6 @@ def main():
             sample_frequency=lora_sample_frequency,
         )
 
-        response = input("Run another cycle? [y/n] ")
-        continue_training = response.lower().strip() != "n"
-        training_iter += 1
-
 
 if __name__ == "__main__":
     main()
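
Note (not part of the patch): the reworked loop runs the first --auto_cycles cycles back to back and only asks "Run another cycle?" once that budget is used up; any answer other than "n" keeps going. A standalone sketch of the control flow, with the actual trainer(...) call replaced by a hypothetical run_cycle():

    def run_cycle() -> None:
        """Hypothetical stand-in for the trainer(...) call in main()."""
        pass

    def train_cycles(auto_cycles: int) -> None:
        training_iter = 0
        while True:
            training_iter += 1
            # prompt only once the automatic cycle budget is exhausted
            if training_iter > auto_cycles:
                response = input("Run another cycle? [y/n] ")
                if response.lower().strip() == "n":
                    break
            print(f"============ LoRA cycle {training_iter} ============")
            run_cycle()

    train_cycles(auto_cycles=1)  # with the default of 1, the prompt first appears before cycle 2

With --auto_cycles 3, cycles 1 through 3 run without interaction and the prompt first appears before cycle 4.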
diff --git a/train_ti.py b/train_ti.py
index ca5b113..ebac302 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -64,6 +64,12 @@ def parse_args():
64 help="The name of the current project.", 64 help="The name of the current project.",
65 ) 65 )
66 parser.add_argument( 66 parser.add_argument(
67 "--auto_cycles",
68 type=int,
69 default=1,
70 help="How many cycles to run automatically."
71 )
72 parser.add_argument(
67 "--placeholder_tokens", 73 "--placeholder_tokens",
68 type=str, 74 type=str,
69 nargs='*', 75 nargs='*',
@@ -869,10 +875,15 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    continue_training = True
-    training_iter = 1
+    training_iter = 0
+
+    while True:
+        training_iter += 1
+        if training_iter > args.auto_cycles:
+            response = input("Run another cycle? [y/n] ")
+            if response.lower().strip() == "n":
+                break
 
-    while continue_training:
         print("")
         print(f"------------ TI cycle {training_iter} ------------")
         print("")
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 289d6bd..9cdc1bb 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -119,7 +119,7 @@ def textual_inversion_strategy_callbacks(
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters())
 
         if use_emb_decay and w is not None:
-            lr = lrs["emb"] or lrs["0"]
+            lr = lrs["emb"] if "emb" in lrs else lrs["0"]
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
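
Note (not part of the patch): lrs["emb"] or lrs["0"] never actually falls back on a missing key, because the subscript raises KeyError before the or is evaluated, and a legitimate learning rate of 0.0 would also be discarded as falsy. The conditional expression falls back only when "emb" is absent and keeps an explicit 0.0. A minimal illustration with hypothetical values:

    lrs = {"0": 1e-4}                                # no separate "emb" group in this run
    # lrs["emb"] or lrs["0"]                         # KeyError: 'emb' -- the fallback is never reached
    lr = lrs["emb"] if "emb" in lrs else lrs["0"]
    print(lr)                                        # 0.0001

    lrs = {"emb": 0.0, "0": 1e-4}                    # embedding group present but frozen at 0.0
    print(lrs["emb"] or lrs["0"])                    # 0.0001 -- falsy 0.0 silently replaced
    print(lrs["emb"] if "emb" in lrs else lrs["0"])  # 0.0 -- kept as intended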