Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  51
1 file changed, 32 insertions(+), 19 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 73b3e19..1ca56d9 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -247,9 +246,15 @@ def parse_args():
         help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate",
+        "--learning_rate_unet",
         type=float,
-        default=2e-6,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
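The single --learning_rate flag is split in two so the text encoder can be trained more gently than the UNet (defaults 5e-5 vs. 1e-4). A minimal sketch that isolates just the two new flags, useful for confirming the defaults; the bare parser here is illustrative, not the script's full parse_args:

import argparse

# Only the flag names, types, and defaults are taken from the hunk above;
# the stripped-down parser itself is for illustration.
parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate_unet", type=float, default=1e-4)
parser.add_argument("--learning_rate_text", type=float, default=5e-5)

args = parser.parse_args([])
print(args.learning_rate_unet, args.learning_rate_text)  # 0.0001 5e-05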
@@ -548,13 +553,18 @@ def main():
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
+        args.learning_rate_unet = (
+            args.learning_rate_unet * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+        args.learning_rate_text = (
+            args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
-        args.learning_rate = 1e-6
+        args.learning_rate_unet = 1e-6
+        args.learning_rate_text = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
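The scale_lr branch now applies the linear scaling rule to both rates independently: each base learning rate is multiplied by the gradient-accumulation steps, the per-device batch size, and the number of processes, while find_lr resets both rates to 1e-6 before the exponential-growth sweep. A worked example of the scaling; the numbers below are illustrative, not the script's defaults:

# Illustrative values; in the script they come from argparse and accelerate.
learning_rate_unet = 1e-4
learning_rate_text = 5e-5
gradient_accumulation_steps = 4
train_batch_size = 2
num_processes = 2

scale = gradient_accumulation_steps * train_batch_size * num_processes  # 16
print(learning_rate_unet * scale)  # 0.0016
print(learning_rate_text * scale)  # 0.0008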
@@ -611,8 +621,8 @@ def main():
         )
 
         args.lr_scheduler = "adafactor"
-        args.lr_min_lr = args.learning_rate
-        args.learning_rate = None
+        args.lr_min_lr = args.learning_rate_unet
+        args.learning_rate_unet = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
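In the adafactor branch, the UNet rate is stashed in args.lr_min_lr and then cleared, which is consistent with the transformers Adafactor implementation computing its own relative step sizes when no learning rate is given. A hedged sketch under that assumption; the diff does not show which keyword arguments the script actually passes:

import torch
from transformers.optimization import Adafactor

# Assumption: the 'adafactor' branch wraps transformers' Adafactor. With
# lr=None and relative_step=True the optimizer derives its own step size,
# which is why the hunk above sets args.learning_rate_unet to None.
model = torch.nn.Linear(4, 4)
optimizer = Adafactor(
    model.parameters(),
    lr=None,
    relative_step=True,
    scale_parameter=True,
    warmup_init=True,
)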
@@ -628,7 +638,8 @@ def main():
             d0=args.dadaptation_d0,
         )
 
-        args.learning_rate = 1.0
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     elif args.optimizer == 'dadan':
         try:
             import dadaptation
@@ -642,7 +653,8 @@ def main():
             d0=args.dadaptation_d0,
         )
 
-        args.learning_rate = 1.0
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
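Both D-Adaptation branches pin the learning rates to 1.0: the dadaptation optimizers estimate the step size themselves, and lr only scales that estimate, while d0 seeds the initial distance estimate. A rough sketch; the class name is an assumption based on the 'dadam' branch name, since only the "import dadaptation" line and the d0 keyword are visible in the hunks above:

import torch
import dadaptation

# lr=1.0 keeps the adapted step size unscaled; d0 is the initial D estimate.
# DAdaptAdam is assumed for the 'dadam' branch (likewise DAdaptAdan for 'dadan').
model = torch.nn.Linear(4, 4)
optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0, d0=1e-6)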
@@ -695,15 +707,16 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        (
-            param
-            for param in itertools.chain(
-                unet.parameters(),
-                text_encoder.parameters(),
-            )
-            if param.requires_grad
-        ),
-        lr=args.learning_rate,
+        [
+            {
+                "params": unet.parameters(),
+                "lr": args.learning_rate_unet,
+            },
+            {
+                "params": text_encoder.parameters(),
+                "lr": args.learning_rate_text,
+            },
+        ]
     )
 
     lr_scheduler = get_scheduler(
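The generator expression chained over both models (the reason for the itertools import removed at the top) is replaced by a list of PyTorch parameter groups, so a single optimizer can apply a different learning rate to the UNet and the text encoder. Assuming create_optimizer forwards its first argument to a torch.optim constructor, the equivalent direct call looks roughly like this; AdamW and the stand-in modules are illustrative:

import torch

# Stand-ins for the UNet and text encoder; parameter groups work the same way.
unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)

optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters(), "lr": 1e-4},          # --learning_rate_unet default
        {"params": text_encoder.parameters(), "lr": 5e-5},  # --learning_rate_text default
    ]
)
for group in optimizer.param_groups:
    print(group["lr"])  # 0.0001, then 5e-05

Note that the old requires_grad filter is gone; parameters that stay frozen simply never receive gradients, and the optimizer skips any parameter whose grad is None.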