From ab24e5cbd8283ad4ced486e1369484ebf9e3962d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 6 Apr 2023 16:06:04 +0200
Subject: Update

---
 train_lora.py | 51 ++++++++++++++++++++++++++++++++-------------------
 1 file changed, 32 insertions(+), 19 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index 73b3e19..1ca56d9 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -247,9 +246,15 @@ def parse_args():
         help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate",
+        "--learning_rate_unet",
         type=float,
-        default=2e-6,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -548,13 +553,18 @@ def main():
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
+        args.learning_rate_unet = (
+            args.learning_rate_unet * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+        args.learning_rate_text = (
+            args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
-        args.learning_rate = 1e-6
+        args.learning_rate_unet = 1e-6
+        args.learning_rate_text = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
@@ -611,8 +621,8 @@
         )
 
         args.lr_scheduler = "adafactor"
-        args.lr_min_lr = args.learning_rate
-        args.learning_rate = None
+        args.lr_min_lr = args.learning_rate_unet
+        args.learning_rate_unet = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -628,7 +638,8 @@
             d0=args.dadaptation_d0,
         )
 
-        args.learning_rate = 1.0
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     elif args.optimizer == 'dadan':
         try:
             import dadaptation
@@ -642,7 +653,8 @@
             d0=args.dadaptation_d0,
         )
 
-        args.learning_rate = 1.0
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -695,15 +707,16 @@
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        (
-            param
-            for param in itertools.chain(
-                unet.parameters(),
-                text_encoder.parameters(),
-            )
-            if param.requires_grad
-        ),
-        lr=args.learning_rate,
+        [
+            {
+                "params": unet.parameters(),
+                "lr": args.learning_rate_unet,
+            },
+            {
+                "params": text_encoder.parameters(),
+                "lr": args.learning_rate_text,
+            },
+        ]
     )
 
     lr_scheduler = get_scheduler(
--
cgit v1.2.3-54-g00ecf
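
For context, a minimal sketch of the per-parameter-group learning-rate pattern this patch switches to. It is not part of the patch: plain torch.optim.AdamW stands in for the repository's create_optimizer factory (which, as the diff shows, may also wrap 8-bit Adam, Adafactor, or the dadaptation optimizers), and small nn.Linear modules stand in for the real UNet and text encoder.

    import torch

    # Stand-ins for the real models; the actual script passes the diffusers
    # UNet and the transformers text encoder here.
    unet = torch.nn.Linear(4, 4)
    text_encoder = torch.nn.Linear(4, 4)

    # One parameter group per model, each with its own learning rate
    # (matching the new --learning_rate_unet / --learning_rate_text defaults).
    optimizer = torch.optim.AdamW([
        {"params": unet.parameters(), "lr": 1e-4},
        {"params": text_encoder.parameters(), "lr": 5e-5},
    ])

    # Each group keeps its own "lr", so the UNet and the text encoder can be
    # scaled, scheduled, and tuned independently.
    print([group["lr"] for group in optimizer.param_groups])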