 models/clip/embeddings.py |  5 ++++-
 train_lora.py             | 21 ++++++++++++++-------
 train_ti.py               | 17 ++++++++++++++---
 training/strategy/ti.py   |  2 +-
 4 files changed, 33 insertions(+), 12 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index dc4708a..9be8256 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -97,7 +97,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device)
+        input_ids = torch.arange(
+            self.token_embedding.num_embeddings,
+            device=self.token_override_embedding.mapping.device
+        )
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
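
The persist() change above also moves the device lookup from the module to its mapping tensor. Assuming token_override_embedding is an ordinary nn.Module, the old attribute access would fail, since a plain module defines no .device; tensors (including registered buffers) always carry one. A minimal sketch of the failure mode, with a hypothetical OverrideEmbedding standing in for the real class:

    import torch
    import torch.nn as nn

    class OverrideEmbedding(nn.Module):  # hypothetical stand-in for token_override_embedding
        def __init__(self, num_embeddings: int):
            super().__init__()
            # A registered buffer is a tensor, and tensors always know their device.
            self.register_buffer("mapping", torch.zeros(num_embeddings, dtype=torch.long))

    emb = OverrideEmbedding(49408)
    print(emb.mapping.device)  # fine, e.g. "cpu"
    try:
        print(emb.device)      # a plain nn.Module exposes no .device attribute
    except AttributeError as e:
        print(e)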
diff --git a/train_lora.py b/train_lora.py
index 54c9e7a..e81742a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -81,6 +81,12 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
+        "--auto_cycles",
+        type=int,
+        default=1,
+        help="How many cycles to run automatically."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -933,10 +939,15 @@ def main():
         train_epochs=num_train_epochs,
     )
 
-    continue_training = True
-    training_iter = 1
+    training_iter = 0
+
+    while True:
+        training_iter += 1
+        if training_iter > args.auto_cycles:
+            response = input("Run another cycle? [y/n] ")
+            if response.lower().strip() == "n":
+                break
 
-    while continue_training:
         print("")
         print(f"============ LoRA cycle {training_iter} ============")
         print("")
@@ -961,10 +972,6 @@ def main():
             sample_frequency=lora_sample_frequency,
        )
 
-        response = input("Run another cycle? [y/n] ")
-        continue_training = response.lower().strip() != "n"
-        training_iter += 1
-
 
 if __name__ == "__main__":
     main()
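
The main() rewrite here (mirrored in train_ti.py below) moves the continue/stop decision to the top of the loop: the first --auto_cycles iterations run unprompted, and the [y/n] prompt only appears once those are exhausted. With the default of 1 this reproduces the old behavior of prompting after the first cycle. A minimal sketch of the pattern, with a hypothetical run_cycle() standing in for the actual training call:

    def run_cycle():  # hypothetical stand-in for the per-cycle training code
        pass

    def train(auto_cycles: int = 1):
        training_iter = 0
        while True:
            training_iter += 1
            # Prompt only after the automatic cycles are used up.
            if training_iter > auto_cycles:
                response = input("Run another cycle? [y/n] ")
                if response.lower().strip() == "n":
                    break
            print(f"============ LoRA cycle {training_iter} ============")
            run_cycle()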
diff --git a/train_ti.py b/train_ti.py
index ca5b113..ebac302 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -64,6 +64,12 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
+        "--auto_cycles",
+        type=int,
+        default=1,
+        help="How many cycles to run automatically."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -869,10 +875,15 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    continue_training = True
-    training_iter = 1
+    training_iter = 0
+
+    while True:
+        training_iter += 1
+        if training_iter > args.auto_cycles:
+            response = input("Run another cycle? [y/n] ")
+            if response.lower().strip() == "n":
+                break
 
-    while continue_training:
         print("")
         print(f"------------ TI cycle {training_iter} ------------")
         print("")
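
The --auto_cycles flag added to parse_args() in both scripts is a plain integer with a default of 1, so existing invocations keep working without changes. How the argument behaves in isolation:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--auto_cycles",
        type=int,
        default=1,
        help="How many cycles to run automatically."
    )

    print(parser.parse_args([]).auto_cycles)                      # 1
    print(parser.parse_args(["--auto_cycles", "3"]).auto_cycles)  # 3

An invocation would then look like `python train_ti.py --auto_cycles 3` with the remaining arguments unchanged.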
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 289d6bd..9cdc1bb 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -119,7 +119,7 @@ def textual_inversion_strategy_callbacks(
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters())
 
         if use_emb_decay and w is not None:
-            lr = lrs["emb"] or lrs["0"]
+            lr = lrs["emb"] if "emb" in lrs else lrs["0"]
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
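
The one-line change in training/strategy/ti.py fixes the fallback for the embedding learning rate. `lrs["emb"] or lrs["0"]` subscripts before `or` can intervene, so a missing "emb" key raises KeyError instead of falling back, and a legitimate 0.0 value would be silently replaced by the fallback. The membership test avoids both. A small demonstration:

    lrs = {"0": 1e-4}  # scheduler output without a dedicated "emb" group

    # Old form: the subscript runs first, so the fallback never gets a chance.
    try:
        lr = lrs["emb"] or lrs["0"]
    except KeyError as e:
        print("old form raised:", e)

    # New form: test membership first; also keeps an explicit 0.0 for "emb".
    lr = lrs["emb"] if "emb" in lrs else lrs["0"]
    print(lr)  # 0.0001

`lrs.get("emb", lrs["0"])` would be an equivalent, terser spelling, at the cost of always evaluating the fallback.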