path: root/train_lora.py
author     Volpeon <git@volpeon.ink>  2023-03-24 10:53:16 +0100
committer  Volpeon <git@volpeon.ink>  2023-03-24 10:53:16 +0100
commit     95adaea8b55d8e3755c035758bc649ae22548572 (patch)
tree       80239f0bc55b99615718a935be2caa2e1e68e20a /train_lora.py
parent     Bring back Perlin offset noise (diff)
Refactoring, fixed Lora training
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  73
1 file changed, 72 insertions(+), 1 deletion(-)
diff --git a/train_lora.py b/train_lora.py
index 8dd3c86..fa24cee 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -11,6 +11,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
+from peft import LoraConfig, LoraModel
 from slugify import slugify
 
 from util.files import load_config, load_embeddings_from_dir
@@ -21,6 +22,11 @@ from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
+# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
+UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+
 logger = get_logger(__name__)
 
 
@@ -176,6 +182,54 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--lora_r",
+        type=int,
+        default=8,
+        help="Lora rank, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="Lora dropout, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_bias",
+        type=str,
+        default="none",
+        help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_r",
+        type=int,
+        default=8,
+        help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_dropout",
+        type=float,
+        default=0.0,
+        help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_bias",
+        type=str,
+        default="none",
+        help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
+    )
+    parser.add_argument(
         "--find_lr",
         action="store_true",
         help="Automatically find a learning rate (no training).",
@@ -424,13 +478,30 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
+    unet_config = LoraConfig(
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        target_modules=UNET_TARGET_MODULES,
+        lora_dropout=args.lora_dropout,
+        bias=args.lora_bias,
+    )
+    unet = LoraModel(unet_config, unet)
+
+    text_encoder_config = LoraConfig(
+        r=args.lora_text_encoder_r,
+        lora_alpha=args.lora_text_encoder_alpha,
+        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        lora_dropout=args.lora_text_encoder_dropout,
+        bias=args.lora_text_encoder_bias,
+    )
+    text_encoder = LoraModel(text_encoder_config, text_encoder)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
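
For context, below is a minimal sketch (not part of the commit) of what the new PEFT wrapping does: it applies the same LoraConfig/LoraModel pattern to a diffusers UNet with the default values of the new --lora_* flags and reports how many parameters remain trainable. It assumes a peft release from around the time of this commit, where LoraModel is constructed as LoraModel(config, model); the checkpoint name is only a hypothetical example.

# Sketch only, not part of the commit. Mirrors the UNet wrapping added above;
# assumes an early-2023 peft release (LoraModel(config, model)) and a
# hypothetical Stable Diffusion checkpoint.
from diffusers import UNet2DConditionModel
from peft import LoraConfig, LoraModel

UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]  # same list as the diff

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # hypothetical example checkpoint
    subfolder="unet",
)

unet_config = LoraConfig(
    r=8,                # --lora_r
    lora_alpha=32,      # --lora_alpha
    target_modules=UNET_TARGET_MODULES,
    lora_dropout=0.0,   # --lora_dropout
    bias="none",        # --lora_bias
)
unet = LoraModel(unet_config, unet)

# LoraModel injects low-rank adapter weights into the targeted projection
# layers and marks only those adapters as trainable, so the fraction printed
# here should be small.
trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

The text encoder is wrapped the same way in the hunk above, just with TEXT_ENCODER_TARGET_MODULES and the --lora_text_encoder_* flags.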