author      Volpeon <git@volpeon.ink>   2023-04-06 16:06:04 +0200
committer   Volpeon <git@volpeon.ink>   2023-04-06 16:06:04 +0200
commit      ab24e5cbd8283ad4ced486e1369484ebf9e3962d (patch)
tree        7d47b7cd38e7313071ad4a671b14f8a23dcd7389 /train_lora.py
parent      MinSNR code from diffusers (diff)
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r--   train_lora.py   51
1 file changed, 32 insertions, 19 deletions
diff --git a/train_lora.py b/train_lora.py
index 73b3e19..1ca56d9 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -247,9 +246,15 @@ def parse_args():
         help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate",
+        "--learning_rate_unet",
         type=float,
-        default=2e-6,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
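This hunk replaces the single --learning_rate option with separate --learning_rate_unet (default 1e-4) and --learning_rate_text (default 5e-5) options. A minimal standalone sketch of the resulting interface, with all other train_lora.py arguments omitted (not code from the repository):

    import argparse

    # Stripped-down parser containing only the two flags introduced in this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate_unet", type=float, default=1e-4,
                        help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--learning_rate_text", type=float, default=5e-5,
                        help="Initial learning rate (after the potential warmup period) to use.")

    args = parser.parse_args(["--learning_rate_unet", "2e-4"])
    print(args.learning_rate_unet, args.learning_rate_text)  # 0.0002 5e-05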
@@ -548,13 +553,18 @@ def main():
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
+        args.learning_rate_unet = (
+            args.learning_rate_unet * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+        args.learning_rate_text = (
+            args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
-        args.learning_rate = 1e-6
+        args.learning_rate_unet = 1e-6
+        args.learning_rate_text = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
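With --scale_lr set, both learning rates are now scaled by the same factor: gradient accumulation steps times per-device batch size times number of processes. A small sketch of that computation; the example values are assumptions, not defaults from the script:

    # Mirrors the --scale_lr logic in the hunk above.
    def scale_lr(base_lr, gradient_accumulation_steps, train_batch_size, num_processes):
        return base_lr * gradient_accumulation_steps * train_batch_size * num_processes

    # e.g. a UNet base rate of 1e-4 with 4 accumulation steps, batch size 2, 1 process:
    # 1e-4 * 4 * 2 * 1 = 8e-4
    scaled = scale_lr(1e-4, gradient_accumulation_steps=4, train_batch_size=2, num_processes=1)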
@@ -611,8 +621,8 @@ def main():
         )
 
         args.lr_scheduler = "adafactor"
-        args.lr_min_lr = args.learning_rate
-        args.learning_rate = None
+        args.lr_min_lr = args.learning_rate_unet
+        args.learning_rate_unet = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -628,7 +638,8 @@
             d0=args.dadaptation_d0,
         )
 
-        args.learning_rate = 1.0
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     elif args.optimizer == 'dadan':
         try:
             import dadaptation
@@ -642,7 +653,8 @@
             d0=args.dadaptation_d0,
         )
 
-        args.learning_rate = 1.0
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -695,15 +707,16 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        (
-            param
-            for param in itertools.chain(
-                unet.parameters(),
-                text_encoder.parameters(),
-            )
-            if param.requires_grad
-        ),
-        lr=args.learning_rate,
+        [
+            {
+                "params": unet.parameters(),
+                "lr": args.learning_rate_unet,
+            },
+            {
+                "params": text_encoder.parameters(),
+                "lr": args.learning_rate_text,
+            },
+        ]
     )
 
     lr_scheduler = get_scheduler(
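The generator that chained and filtered trainable parameters with a single lr keyword is replaced by two optimizer parameter groups, one per module, each with its own learning rate. A minimal sketch of the same pattern using a plain torch.optim.AdamW; the tiny Linear modules stand in for the LoRA-wrapped UNet and text encoder, and create_optimizer is assumed to forward its argument to such an optimizer:

    import torch

    unet = torch.nn.Linear(8, 8)          # stand-in for the UNet
    text_encoder = torch.nn.Linear(8, 8)  # stand-in for the text encoder

    # One parameter group per module, each with its own learning rate.
    optimizer = torch.optim.AdamW(
        [
            {"params": unet.parameters(), "lr": 1e-4},
            {"params": text_encoder.parameters(), "lr": 5e-5},
        ]
    )

    for group in optimizer.param_groups:
        print(group["lr"])  # 0.0001, then 5e-05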