From 95adaea8b55d8e3755c035758bc649ae22548572 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 24 Mar 2023 10:53:16 +0100
Subject: Refactoring, fixed Lora training

---
 train_lora.py                   | 73 ++++++++++++++++++++++++++++++++++++++++-
 training/functional.py          |  9 ++---
 training/strategy/dreambooth.py | 17 ++++------
 training/strategy/lora.py       | 49 +++++++--------------------
 training/strategy/ti.py         | 22 +++++++------
 5 files changed, 104 insertions(+), 66 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 8dd3c86..fa24cee 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -11,6 +11,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
+from peft import LoraConfig, LoraModel
 from slugify import slugify
 
 from util.files import load_config, load_embeddings_from_dir
@@ -21,6 +22,11 @@ from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
+# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
+UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+
 logger = get_logger(__name__)
 
 
@@ -175,6 +181,54 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--lora_r",
+        type=int,
+        default=8,
+        help="Lora rank, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="Lora dropout, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_bias",
+        type=str,
+        default="none",
+        help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_r",
+        type=int,
+        default=8,
+        help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_dropout",
+        type=float,
+        default=0.0,
+        help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+    )
+    parser.add_argument(
+        "--lora_text_encoder_bias",
+        type=str,
+        default="none",
+        help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
+    )
     parser.add_argument(
         "--find_lr",
         action="store_true",
@@ -424,13 +478,30 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
+    unet_config = LoraConfig(
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        target_modules=UNET_TARGET_MODULES,
+        lora_dropout=args.lora_dropout,
+        bias=args.lora_bias,
+    )
+    unet = LoraModel(unet_config, unet)
+
+    text_encoder_config = LoraConfig(
+        r=args.lora_text_encoder_r,
+        lora_alpha=args.lora_text_encoder_alpha,
+        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        lora_dropout=args.lora_text_encoder_dropout,
+        bias=args.lora_text_encoder_bias,
+    )
+    text_encoder = LoraModel(text_encoder_config, text_encoder)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
diff --git a/training/functional.py b/training/functional.py
index a5b339d..ee73ab2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,6 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_prepare: Callable[[], None] = const()
     on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -620,10 +619,8 @@ def train(
     kwargs.update(extra)
 
     vae.to(accelerator.device, dtype=dtype)
-
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
+    vae.requires_grad_(False)
+    vae.eval()
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
@@ -636,8 +633,6 @@ def train(
         **kwargs,
     )
 
-    callbacks.on_prepare()
-
     loss_step_ = partial(
         loss_step,
         vae,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 28fccff..9808027 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -74,6 +74,7 @@ def dreambooth_strategy_callbacks(
             power=ema_power,
             max_value=ema_max_decay,
         )
+        ema_unet.to(accelerator.device)
     else:
         ema_unet = None
 
@@ -86,14 +87,6 @@ def dreambooth_strategy_callbacks(
     def on_accum_model():
         return unet
 
-    def on_prepare():
-        unet.requires_grad_(True)
-        text_encoder.text_model.encoder.requires_grad_(True)
-        text_encoder.text_model.final_layer_norm.requires_grad_(True)
-
-        if ema_unet is not None:
-            ema_unet.to(accelerator.device)
-
     @contextmanager
     def on_train(epoch: int):
         tokenizer.train()
@@ -181,7 +174,6 @@ def dreambooth_strategy_callbacks(
             torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -203,7 +195,12 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.requires_grad_(False)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1c8fad6..3971eae 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -10,18 +10,12 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
-from peft import LoraConfig, LoraModel, get_peft_model_state_dict
-from peft.tuners.lora import mark_only_lora_as_trainable
+from peft import get_peft_model_state_dict
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
-# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
-UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
-
-
 def lora_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
@@ -61,10 +55,6 @@ def lora_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    def on_prepare():
-        mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
-        mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
-
     def on_accum_model():
         return unet
 
@@ -93,15 +83,15 @@ def lora_strategy_callbacks(
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
         lora_config = {}
-        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+        state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+        lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
         text_encoder_state_dict = get_peft_model_state_dict(
-            text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
         )
         text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
         state_dict.update(text_encoder_state_dict)
-        lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+        lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         accelerator.print(state_dict)
         accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
@@ -111,11 +101,16 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
     def on_sample(step):
+        vae_dtype = vae.dtype
+        vae.to(dtype=text_encoder.dtype)
+
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
         save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
+        vae.to(dtype=vae_dtype)
+
         del unet_
         del text_encoder_
 
@@ -123,7 +118,6 @@ def lora_strategy_callbacks(
             torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -147,28 +141,7 @@ def lora_prepare(
     lora_bias: str = "none",
     **kwargs
 ):
-    unet_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_alpha,
-        target_modules=UNET_TARGET_MODULES,
-        lora_dropout=lora_dropout,
-        bias=lora_bias,
-    )
-    unet = LoraModel(unet_config, unet)
-
-    text_encoder_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
-        lora_dropout=lora_dropout,
-        bias=lora_bias,
-    )
-    text_encoder = LoraModel(text_encoder_config, text_encoder)
-
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
 
 
 lora_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 2038e34..10bc6d7 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -78,6 +78,7 @@ def textual_inversion_strategy_callbacks(
             power=ema_power,
             max_value=ema_max_decay,
         )
+        ema_embeddings.to(accelerator.device)
     else:
         ema_embeddings = None
 
@@ -92,15 +93,6 @@ def textual_inversion_strategy_callbacks(
     def on_accum_model():
         return text_encoder.text_model.embeddings.temp_token_embedding
 
-    def on_prepare():
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
-
-        if ema_embeddings is not None:
-            ema_embeddings.to(accelerator.device)
-
-        if gradient_checkpointing:
-            unet.train()
-
     @contextmanager
     def on_train(epoch: int):
         tokenizer.train()
@@ -177,7 +169,6 @@ def textual_inversion_strategy_callbacks(
            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
@@ -197,6 +188,7 @@ def textual_inversion_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    gradient_checkpointing: bool = False,
     **kwargs
 ):
     weight_dtype = torch.float32
@@ -207,7 +199,17 @@ def textual_inversion_prepare(
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
+    unet.to(accelerator.device, dtype=weight_dtype)
+    unet.requires_grad_(False)
+    unet.eval()
+    if gradient_checkpointing:
+        unet.train()
+
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.eval()
+
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
-- 
cgit v1.2.3-70-g09d2
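
A minimal loading sketch (not from the patch itself): the new on_checkpoint callback in training/strategy/lora.py writes a single .pt that holds the UNet LoRA weights as-is and the text-encoder LoRA weights under a "text_encoder_" prefix. The code below shows one way such a checkpoint might be loaded back for inference. It assumes peft also exports set_peft_model_state_dict next to get_peft_model_state_dict; the model id, checkpoint path, and rank/alpha values are placeholders, not values taken from the patch.

# Sketch only: reload a LoRA checkpoint written by on_checkpoint in
# training/strategy/lora.py. Assumptions: peft exports set_peft_model_state_dict,
# and the same LoraConfig settings as train_lora.py were used for training.
# The model id, path, and r/lora_alpha values below are placeholders.
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, LoraModel, set_peft_model_state_dict

UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Wrap the models the same way train_lora.py does before training,
# so the LoRA parameter names line up with the saved state dict.
pipe.unet = LoraModel(
    LoraConfig(r=8, lora_alpha=32, target_modules=UNET_TARGET_MODULES, bias="none"),
    pipe.unet,
)
pipe.text_encoder = LoraModel(
    LoraConfig(r=8, lora_alpha=32, target_modules=TEXT_ENCODER_TARGET_MODULES, bias="none"),
    pipe.text_encoder,
)

# The checkpoint mixes both models in one dict; text-encoder keys carry the
# "text_encoder_" prefix added by on_checkpoint, UNet keys are unprefixed.
state_dict = torch.load("checkpoints/10000_end.pt", map_location="cpu")
prefix = "text_encoder_"
unet_state = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
text_state = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

set_peft_model_state_dict(pipe.unet, unet_state)
set_peft_model_state_dict(pipe.text_encoder, text_state)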