From a0b63ee7f4a8c793c0d200c86ef07677aa4cbf2e Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 13 Apr 2023 07:14:24 +0200
Subject: Update

---
 train_lora.py | 73 +++++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 46 insertions(+), 27 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 29e40b2..073e939 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -86,6 +86,12 @@ def parse_args():
         default=1,
         help="How many cycles to run automatically."
     )
+    parser.add_argument(
+        "--cycle_decay",
+        type=float,
+        default=1.0,
+        help="Learning rate decay per cycle."
+    )
     parser.add_argument(
         "--placeholder_tokens",
         type=str,
@@ -924,39 +930,15 @@ def main():
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)

-    params_to_optimize = []
     group_labels = []
     if len(args.placeholder_tokens) != 0:
-        params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-            "lr": args.learning_rate_emb,
-            "weight_decay": 0,
-        })
         group_labels.append("emb")
-    params_to_optimize += [
-        {
-            "params": (
-                param
-                for param in unet.parameters()
-                if param.requires_grad
-            ),
-            "lr": args.learning_rate_unet,
-        },
-        {
-            "params": (
-                param
-                for param in itertools.chain(
-                    text_encoder.text_model.encoder.parameters(),
-                    text_encoder.text_model.final_layer_norm.parameters(),
-                )
-                if param.requires_grad
-            ),
-            "lr": args.learning_rate_text,
-        },
-    ]
     group_labels += ["unet", "text"]

     training_iter = 0
+    learning_rate_emb = args.learning_rate_emb
+    learning_rate_unet = args.learning_rate_unet
+    learning_rate_text = args.learning_rate_text

     lora_project = "lora"

@@ -973,6 +955,37 @@ def main():
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")

+        params_to_optimize = []
+
+        if len(args.placeholder_tokens) != 0:
+            params_to_optimize.append({
+                "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+                "lr": learning_rate_emb,
+                "weight_decay": 0,
+            })
+            group_labels.append("emb")
+        params_to_optimize += [
+            {
+                "params": (
+                    param
+                    for param in unet.parameters()
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_unet,
+            },
+            {
+                "params": (
+                    param
+                    for param in itertools.chain(
+                        text_encoder.text_model.encoder.parameters(),
+                        text_encoder.text_model.final_layer_norm.parameters(),
+                    )
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_text,
+            },
+        ]
+
         lora_optimizer = create_optimizer(params_to_optimize)

         lora_lr_scheduler = create_lr_scheduler(
@@ -1002,6 +1015,12 @@ def main():
         )

         training_iter += 1
+        if args.learning_rate_emb is not None:
+            learning_rate_emb *= args.cycle_decay
+        if args.learning_rate_unet is not None:
+            learning_rate_unet *= args.cycle_decay
+        if args.learning_rate_text is not None:
+            learning_rate_text *= args.cycle_decay

     accelerator.end_training()
--
cgit v1.2.3-70-g09d2
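
The core change in this patch is that the optimizer parameter groups are now rebuilt inside the per-cycle training loop, and the learning rate of each group is multiplied by the new --cycle_decay factor after every cycle. Below is a minimal standalone sketch of that pattern, assuming a toy torch model; the model, cycle count, and variable values are illustrative only and are not part of train_lora.py.

# Minimal sketch of the per-cycle learning-rate decay introduced by this patch.
# The toy model and the concrete values are hypothetical stand-ins.
import torch

model = torch.nn.Linear(4, 4)   # stand-in for the UNet / text encoder parameters
learning_rate = 1e-4            # corresponds to e.g. --learning_rate_unet
cycle_decay = 0.9               # corresponds to --cycle_decay (default 1.0 = no decay)
num_cycles = 3                  # corresponds to the auto-cycles argument

for cycle in range(num_cycles):
    # Rebuild the param groups each cycle so the (possibly decayed) learning
    # rate is baked into a fresh optimizer, mirroring the patch moving
    # params_to_optimize into the cycle loop.
    params_to_optimize = [{
        "params": (p for p in model.parameters() if p.requires_grad),
        "lr": learning_rate,
    }]
    optimizer = torch.optim.AdamW(params_to_optimize)
    print(f"cycle {cycle + 1}: lr = {optimizer.param_groups[0]['lr']:.2e}")

    # ... one cycle of training would run here ...

    # Decay the learning rate for the next cycle, as the patch does right
    # after incrementing training_iter.
    learning_rate *= cycle_decay

With cycle_decay left at its default of 1.0 the behaviour is unchanged; values below 1.0 shrink the learning rates geometrically from one LoRA cycle to the next.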