From dd4daf3f56483f44122500dc4905309541b7ef81 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 8 Apr 2023 17:44:11 +0200 Subject: Update --- train_lora.py | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/train_lora.py b/train_lora.py index e4b5546..d8a4880 100644 --- a/train_lora.py +++ b/train_lora.py @@ -322,6 +322,17 @@ def parse_args(): default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) + parser.add_argument( + "--learning_rate_emb", + type=float, + default=1e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--train_emb", + action="store_true", + help="Keep training text embeddings.", + ) parser.add_argument( "--scale_lr", action="store_true", @@ -731,11 +742,16 @@ def main(): args.learning_rate_pti * args.pti_gradient_accumulation_steps * args.pti_batch_size * accelerator.num_processes ) + args.learning_rate_emb = ( + args.learning_rate_emb * args.pti_gradient_accumulation_steps * + args.pti_batch_size * accelerator.num_processes + ) if args.find_lr: args.learning_rate_unet = 1e-6 args.learning_rate_text = 1e-6 args.learning_rate_pti = 1e-6 + args.learning_rate_emb = 1e-6 args.lr_scheduler = "exponential_growth" if args.optimizer == 'adam8bit': @@ -794,6 +810,9 @@ def main(): args.lr_scheduler = "adafactor" args.lr_min_lr = args.learning_rate_unet args.learning_rate_unet = None + args.learning_rate_text = None + args.learning_rate_pti = None + args.learning_rate_emb = None elif args.optimizer == 'dadam': try: import dadaptation @@ -811,6 +830,8 @@ def main(): args.learning_rate_unet = 1.0 args.learning_rate_text = 1.0 + args.learning_rate_pti = 1.0 + args.learning_rate_emb = 1.0 elif args.optimizer == 'dadan': try: import dadaptation @@ -826,6 +847,8 @@ def main(): args.learning_rate_unet = 1.0 args.learning_rate_text = 1.0 + args.learning_rate_pti = 1.0 + args.learning_rate_emb = 1.0 else: raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") @@ -949,7 +972,8 @@ def main(): sample_frequency=pti_sample_frequency, ) - embeddings.persist() + if not args.train_emb: + embeddings.persist() # LORA # -------------------------------------------------------------------------------- @@ -974,13 +998,13 @@ def main(): params_to_optimize = [] group_labels = [] - # if len(args.placeholder_tokens) != 0: - # params_to_optimize.append({ - # "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), - # "lr": args.learning_rate_text, - # "weight_decay": 0, - # }) - # group_labels.append("emb") + if len(args.placeholder_tokens) != 0 and args.train_emb: + params_to_optimize.append({ + "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), + "lr": args.learning_rate_emb, + "weight_decay": 0, + }) + group_labels.append("emb") params_to_optimize += [ { "params": ( -- cgit v1.2.3-70-g09d2