diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-24 10:53:16 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-24 10:53:16 +0100 |
| commit | 95adaea8b55d8e3755c035758bc649ae22548572 (patch) | |
| tree | 80239f0bc55b99615718a935be2caa2e1e68e20a /train_lora.py | |
| parent | Bring back Perlin offset noise (diff) | |
| download | textual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.tar.gz textual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.tar.bz2 textual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.zip | |
Refactoring, fixed Lora training
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 73 |
1 file changed, 72 insertions, 1 deletion
diff --git a/train_lora.py b/train_lora.py index 8dd3c86..fa24cee 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -11,6 +11,7 @@ import torch.utils.checkpoint | |||
| 11 | from accelerate import Accelerator | 11 | from accelerate import Accelerator |
| 12 | from accelerate.logging import get_logger | 12 | from accelerate.logging import get_logger |
| 13 | from accelerate.utils import LoggerType, set_seed | 13 | from accelerate.utils import LoggerType, set_seed |
| 14 | from peft import LoraConfig, LoraModel | ||
| 14 | from slugify import slugify | 15 | from slugify import slugify |
| 15 | 16 | ||
| 16 | from util.files import load_config, load_embeddings_from_dir | 17 | from util.files import load_config, load_embeddings_from_dir |
| @@ -21,6 +22,11 @@ from training.strategy.lora import lora_strategy | |||
| 21 | from training.optimization import get_scheduler | 22 | from training.optimization import get_scheduler |
| 22 | from training.util import save_args | 23 | from training.util import save_args |
| 23 | 24 | ||
| 25 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | ||
| 26 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | ||
| 27 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | ||
| 28 | |||
| 29 | |||
| 24 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
| 25 | 31 | ||
| 26 | 32 | ||
| @@ -176,6 +182,54 @@ def parse_args(): | |||
| 176 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 182 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
| 177 | ) | 183 | ) |
| 178 | parser.add_argument( | 184 | parser.add_argument( |
| 185 | "--lora_r", | ||
| 186 | type=int, | ||
| 187 | default=8, | ||
| 188 | help="Lora rank, only used if use_lora is True" | ||
| 189 | ) | ||
| 190 | parser.add_argument( | ||
| 191 | "--lora_alpha", | ||
| 192 | type=int, | ||
| 193 | default=32, | ||
| 194 | help="Lora alpha, only used if use_lora is True" | ||
| 195 | ) | ||
| 196 | parser.add_argument( | ||
| 197 | "--lora_dropout", | ||
| 198 | type=float, | ||
| 199 | default=0.0, | ||
| 200 | help="Lora dropout, only used if use_lora is True" | ||
| 201 | ) | ||
| 202 | parser.add_argument( | ||
| 203 | "--lora_bias", | ||
| 204 | type=str, | ||
| 205 | default="none", | ||
| 206 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True", | ||
| 207 | ) | ||
| 208 | parser.add_argument( | ||
| 209 | "--lora_text_encoder_r", | ||
| 210 | type=int, | ||
| 211 | default=8, | ||
| 212 | help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True", | ||
| 213 | ) | ||
| 214 | parser.add_argument( | ||
| 215 | "--lora_text_encoder_alpha", | ||
| 216 | type=int, | ||
| 217 | default=32, | ||
| 218 | help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True", | ||
| 219 | ) | ||
| 220 | parser.add_argument( | ||
| 221 | "--lora_text_encoder_dropout", | ||
| 222 | type=float, | ||
| 223 | default=0.0, | ||
| 224 | help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True", | ||
| 225 | ) | ||
| 226 | parser.add_argument( | ||
| 227 | "--lora_text_encoder_bias", | ||
| 228 | type=str, | ||
| 229 | default="none", | ||
| 230 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", | ||
| 231 | ) | ||
| 232 | parser.add_argument( | ||
| 179 | "--find_lr", | 233 | "--find_lr", |
| 180 | action="store_true", | 234 | action="store_true", |
| 181 | help="Automatically find a learning rate (no training).", | 235 | help="Automatically find a learning rate (no training).", |
| @@ -424,13 +478,30 @@ def main(): | |||
| 424 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 478 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 425 | args.pretrained_model_name_or_path) | 479 | args.pretrained_model_name_or_path) |
| 426 | 480 | ||
| 481 | unet_config = LoraConfig( | ||
| 482 | r=args.lora_r, | ||
| 483 | lora_alpha=args.lora_alpha, | ||
| 484 | target_modules=UNET_TARGET_MODULES, | ||
| 485 | lora_dropout=args.lora_dropout, | ||
| 486 | bias=args.lora_bias, | ||
| 487 | ) | ||
| 488 | unet = LoraModel(unet_config, unet) | ||
| 489 | |||
| 490 | text_encoder_config = LoraConfig( | ||
| 491 | r=args.lora_text_encoder_r, | ||
| 492 | lora_alpha=args.lora_text_encoder_alpha, | ||
| 493 | target_modules=TEXT_ENCODER_TARGET_MODULES, | ||
| 494 | lora_dropout=args.lora_text_encoder_dropout, | ||
| 495 | bias=args.lora_text_encoder_bias, | ||
| 496 | ) | ||
| 497 | text_encoder = LoraModel(text_encoder_config, text_encoder) | ||
| 498 | |||
| 427 | vae.enable_slicing() | 499 | vae.enable_slicing() |
| 428 | vae.set_use_memory_efficient_attention_xformers(True) | 500 | vae.set_use_memory_efficient_attention_xformers(True) |
| 429 | unet.enable_xformers_memory_efficient_attention() | 501 | unet.enable_xformers_memory_efficient_attention() |
| 430 | 502 | ||
| 431 | if args.gradient_checkpointing: | 503 | if args.gradient_checkpointing: |
| 432 | unet.enable_gradient_checkpointing() | 504 | unet.enable_gradient_checkpointing() |
| 433 | text_encoder.gradient_checkpointing_enable() | ||
| 434 | 505 | ||
| 435 | if args.embeddings_dir is not None: | 506 | if args.embeddings_dir is not None: |
| 436 | embeddings_dir = Path(args.embeddings_dir) | 507 | embeddings_dir = Path(args.embeddings_dir) |
