Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 73
1 file changed, 72 insertions(+), 1 deletion(-)
diff --git a/train_lora.py b/train_lora.py
index 8dd3c86..fa24cee 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -11,6 +11,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
+from peft import LoraConfig, LoraModel
 from slugify import slugify
 
 from util.files import load_config, load_embeddings_from_dir
@@ -21,6 +22,11 @@ from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
+# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
+UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+
 logger = get_logger(__name__)
 
 
@@ -176,6 +182,54 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--lora_r",
+        type=int,
+        default=8,
+        help="LoRA rank; only used if `use_lora` is True.",
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=32,
+        help="LoRA alpha; only used if `use_lora` is True.",
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="LoRA dropout; only used if `use_lora` is True.",
+    )
+    parser.add_argument(
+        "--lora_bias",
+        type=str,
+        default="none",
+        help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only'; only used if `use_lora` is True.",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_r",
+        type=int,
+        default=8,
+        help="LoRA rank for the text encoder; only used if `use_lora` and `train_text_encoder` are True.",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_alpha",
+        type=int,
+        default=32,
+        help="LoRA alpha for the text encoder; only used if `use_lora` and `train_text_encoder` are True.",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_dropout",
+        type=float,
+        default=0.0,
+        help="LoRA dropout for the text encoder; only used if `use_lora` and `train_text_encoder` are True.",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_bias",
+        type=str,
+        default="none",
+        help="Bias type for the text encoder's LoRA. Can be 'none', 'all' or 'lora_only'; only used if `use_lora` and `train_text_encoder` are True.",
+    )
+    parser.add_argument(
         "--find_lr",
         action="store_true",
         help="Automatically find a learning rate (no training).",
@@ -424,13 +478,30 @@
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
+    unet_config = LoraConfig(
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        target_modules=UNET_TARGET_MODULES,
+        lora_dropout=args.lora_dropout,
+        bias=args.lora_bias,
+    )
+    unet = LoraModel(unet_config, unet)
+
+    text_encoder_config = LoraConfig(
+        r=args.lora_text_encoder_r,
+        lora_alpha=args.lora_text_encoder_alpha,
+        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        lora_dropout=args.lora_text_encoder_dropout,
+        bias=args.lora_text_encoder_bias,
+    )
+    text_encoder = LoraModel(text_encoder_config, text_encoder)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
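
For intuition about the defaults introduced in parse_args() above (r=8, lora_alpha=32): in the standard LoRA formulation the low-rank update B·A·x is scaled by lora_alpha / r before being added to the frozen layer's output, so these defaults apply a fixed gain of 32 / 8 = 4. The sketch below shows that arithmetic in plain PyTorch, without peft; the feature width d and the init scheme are illustrative assumptions, not values from this commit.

import torch

r, lora_alpha = 8, 32             # the argparse defaults above
scaling = lora_alpha / r          # LoRA gain: 32 / 8 = 4.0

d = 64                            # hypothetical feature width
W = torch.randn(d, d)             # frozen base weight
A = torch.randn(r, d) * 0.01      # trainable down-projection
B = torch.zeros(d, r)             # trainable up-projection; zero init, so the delta starts at 0

x = torch.randn(1, d)
y = x @ W.T + (x @ A.T @ B.T) * scaling   # base output plus scaled low-rank delta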
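
The two LoraModel(...) calls in main() are where peft actually rewires the networks: every submodule whose name matches an entry in UNET_TARGET_MODULES / TEXT_ENCODER_TARGET_MODULES gets low-rank adapter weights injected, and everything else is frozen. Below is a minimal, self-contained sketch of that behavior on a toy module. It assumes the early peft API this commit uses, where LoraModel takes (config, model); newer releases changed the constructor, and the ToyAttention class is invented for illustration.

import torch.nn as nn
from peft import LoraConfig, LoraModel

class ToyAttention(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)   # name matches a target module
        self.to_k = nn.Linear(dim, dim)   # no match: left as a plain frozen Linear
        self.to_v = nn.Linear(dim, dim)   # name matches a target module

    def forward(self, x):
        return self.to_q(x) + self.to_v(x)

config = LoraConfig(
    r=8,                              # --lora_r
    lora_alpha=32,                    # --lora_alpha
    target_modules=["to_q", "to_v"],  # cf. UNET_TARGET_MODULES
    lora_dropout=0.0,                 # --lora_dropout
    bias="none",                      # --lora_bias
)
model = LoraModel(config, ToyAttention())

# peft replaces the matched Linears with LoRA-augmented versions and marks
# only the injected low-rank factors as trainable:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)   # expect only lora_A / lora_B entries for to_q and to_v

The printed names should be limited to the lora_A/lora_B factors, which is why this style of fine-tuning trains only about 2·dim·r extra parameters per matched layer while the base weights stay frozen.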