From 776213e99da4ec389575e797d93de8d8960fa010 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 9 Apr 2023 16:21:52 +0200 Subject: Update --- models/clip/embeddings.py | 5 ++++- train_lora.py | 21 ++++++++++++++------- train_ti.py | 17 ++++++++++++++--- training/strategy/ti.py | 2 +- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index dc4708a..9be8256 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -97,7 +97,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device) + input_ids = torch.arange( + self.token_embedding.num_embeddings, + device=self.token_override_embedding.mapping.device + ) embs, mask = self.token_override_embedding(input_ids) if embs is not None: input_ids = input_ids[mask] diff --git a/train_lora.py b/train_lora.py index 54c9e7a..e81742a 100644 --- a/train_lora.py +++ b/train_lora.py @@ -80,6 +80,12 @@ def parse_args(): default=None, help="The name of the current project.", ) + parser.add_argument( + "--auto_cycles", + type=int, + default=1, + help="How many cycles to run automatically." + ) parser.add_argument( "--placeholder_tokens", type=str, @@ -933,10 +939,15 @@ def main(): train_epochs=num_train_epochs, ) - continue_training = True - training_iter = 1 + training_iter = 0 + + while True: + training_iter += 1 + if training_iter > args.auto_cycles: + response = input("Run another cycle? [y/n] ") + if response.lower().strip() == "n": + break - while continue_training: print("") print(f"============ LoRA cycle {training_iter} ============") print("") @@ -961,10 +972,6 @@ def main(): sample_frequency=lora_sample_frequency, ) - response = input("Run another cycle? [y/n] ") - continue_training = response.lower().strip() != "n" - training_iter += 1 - if __name__ == "__main__": main() diff --git a/train_ti.py b/train_ti.py index ca5b113..ebac302 100644 --- a/train_ti.py +++ b/train_ti.py @@ -63,6 +63,12 @@ def parse_args(): default=None, help="The name of the current project.", ) + parser.add_argument( + "--auto_cycles", + type=int, + default=1, + help="How many cycles to run automatically." + ) parser.add_argument( "--placeholder_tokens", type=str, @@ -869,10 +875,15 @@ def main(): mid_point=args.lr_mid_point, ) - continue_training = True - training_iter = 1 + training_iter = 0 + + while True: + training_iter += 1 + if training_iter > args.auto_cycles: + response = input("Run another cycle? [y/n] ") + if response.lower().strip() == "n": + break - while continue_training: print("") print(f"------------ TI cycle {training_iter} ------------") print("") diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 289d6bd..9cdc1bb 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -119,7 +119,7 @@ def textual_inversion_strategy_callbacks( ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters()) if use_emb_decay and w is not None: - lr = lrs["emb"] or lrs["0"] + lr = lrs["emb"] if "emb" in lrs else lrs["0"] lambda_ = emb_decay * lr if lambda_ != 0: -- cgit v1.2.3-54-g00ecf