From 95adaea8b55d8e3755c035758bc649ae22548572 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 24 Mar 2023 10:53:16 +0100
Subject: Refactoring, fixed Lora training

---
 train_lora.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 72 insertions(+), 1 deletion(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index 8dd3c86..fa24cee 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -11,6 +11,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
+from peft import LoraConfig, LoraModel
 from slugify import slugify
 
 from util.files import load_config, load_embeddings_from_dir
@@ -21,6 +22,11 @@ from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
+# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
+UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+
 logger = get_logger(__name__)
 
 
@@ -175,6 +181,54 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--lora_r",
+        type=int,
+        default=8,
+        help="Lora rank, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="Lora dropout, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_bias",
+        type=str,
+        default="none",
+        help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_r",
+        type=int,
+        default=8,
+        help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_dropout",
+        type=float,
+        default=0.0,
+        help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_bias",
+        type=str,
+        default="none",
+        help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
+    )
     parser.add_argument(
         "--find_lr",
         action="store_true",
@@ -424,13 +478,30 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
+    unet_config = LoraConfig(
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        target_modules=UNET_TARGET_MODULES,
+        lora_dropout=args.lora_dropout,
+        bias=args.lora_bias,
+    )
+    unet = LoraModel(unet_config, unet)
+
+    text_encoder_config = LoraConfig(
+        r=args.lora_text_encoder_r,
+        lora_alpha=args.lora_text_encoder_alpha,
+        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        lora_dropout=args.lora_text_encoder_dropout,
+        bias=args.lora_text_encoder_bias,
+    )
+    text_encoder = LoraModel(text_encoder_config, text_encoder)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
--
cgit v1.2.3-54-g00ecf
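
Note (not part of the commit): the core of this change is the peft LoraConfig/LoraModel wrapping pattern shown in the last hunk. The standalone sketch below illustrates that pattern on a toy module; the ToyAttention class and its sizes are made up for illustration, and the two-argument LoraModel(config, model) call mirrors the peft API as used in this patch, which may differ in later peft releases.

# Standalone sketch (assumption: a peft version contemporary with this commit,
# where LoraModel takes (config, model)); ToyAttention is illustrative only.
import torch
import torch.nn as nn
from peft import LoraConfig, LoraModel


class ToyAttention(nn.Module):
    # Toy module whose projection names match entries in UNET_TARGET_MODULES.
    def __init__(self, dim=64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_v(x)


lora_config = LoraConfig(
    r=8,                               # rank of the update matrices (--lora_r)
    lora_alpha=32,                     # scaling factor (--lora_alpha)
    target_modules=["to_q", "to_v"],   # module names to wrap with LoRA layers
    lora_dropout=0.0,                  # (--lora_dropout)
    bias="none",                       # (--lora_bias)
)
model = LoraModel(lora_config, ToyAttention())

# The wrapped model still runs a normal forward pass ...
out = model(torch.randn(2, 64))

# ... but only the injected low-rank adapter weights remain trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # expected to list only lora_A / lora_B parameters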