From 95adaea8b55d8e3755c035758bc649ae22548572 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 24 Mar 2023 10:53:16 +0100
Subject: Refactoring, fixed Lora training

---
 training/strategy/lora.py | 49 +++++++++++------------------------------------
 1 file changed, 11 insertions(+), 38 deletions(-)

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1c8fad6..3971eae 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -10,18 +10,12 @@ from torch.utils.data import DataLoader
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
-from peft import LoraConfig, LoraModel, get_peft_model_state_dict
-from peft.tuners.lora import mark_only_lora_as_trainable
+from peft import get_peft_model_state_dict
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
-# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
-UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
-
-
 def lora_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -61,10 +55,6 @@
         image_size=sample_image_size,
     )
 
-    def on_prepare():
-        mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
-        mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
-
     def on_accum_model():
         return unet
 
@@ -93,15 +83,15 @@
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
         lora_config = {}
-        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+        state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+        lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
         text_encoder_state_dict = get_peft_model_state_dict(
-            text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
         )
         text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
         state_dict.update(text_encoder_state_dict)
-        lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+        lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         accelerator.print(state_dict)
         accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
@@ -111,11 +101,16 @@
 
     @torch.no_grad()
     def on_sample(step):
+        vae_dtype = vae.dtype
+        vae.to(dtype=text_encoder.dtype)
+
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
         save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
+        vae.to(dtype=vae_dtype)
+
         del unet_
         del text_encoder_
 
@@ -123,7 +118,6 @@
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -147,28 +141,7 @@ def lora_prepare(
     lora_bias: str = "none",
     **kwargs
 ):
-    unet_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_alpha,
-        target_modules=UNET_TARGET_MODULES,
-        lora_dropout=lora_dropout,
-        bias=lora_bias,
-    )
-    unet = LoraModel(unet_config, unet)
-
-    text_encoder_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
-        lora_dropout=lora_dropout,
-        bias=lora_bias,
-    )
-    text_encoder = LoraModel(text_encoder_config, text_encoder)
-
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
 
 
 lora_strategy = TrainingStrategy(
-- 
cgit v1.2.3-54-g00ecf
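
The functional fix in this patch is in on_checkpoint: the PEFT state dicts were previously read from the accelerate-wrapped models, and are now read from the unwrapped unet_/text_encoder_. A minimal sketch of the corrected save path, assuming the peft ~0.2 API this file imports; save_lora_checkpoint is a hypothetical helper name, not part of the patch:

    from pathlib import Path

    from accelerate import Accelerator
    from peft import get_peft_model_state_dict

    def save_lora_checkpoint(accelerator: Accelerator, unet, text_encoder, path: Path):
        # Unwrap first: get_peft_model_state_dict() must see the LoraModel itself,
        # not accelerate's distributed/precision wrapper around it.
        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

        state_dict = get_peft_model_state_dict(
            unet_, state_dict=accelerator.get_state_dict(unet_)
        )
        text_encoder_state_dict = get_peft_model_state_dict(
            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
        )
        # Prefix the text encoder entries so both models share one checkpoint file.
        state_dict.update(
            {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
        )
        accelerator.save(state_dict, path)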
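
With this change, lora_prepare() no longer builds the LoRA wrappers itself; the models are presumably expected to arrive already wrapped, though the diff does not show where that now happens. For reference, a sketch of the removed wrapping under the peft ~0.2 API (LoraModel(config, model); newer peft releases changed this signature); wrap_with_lora is a hypothetical name and the target module lists are the constants this patch deletes:

    from peft import LoraConfig, LoraModel

    UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
    TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]

    def wrap_with_lora(unet, text_encoder, lora_rank, lora_alpha, lora_dropout, lora_bias="none"):
        # Wrap the UNet: only the listed attention projections get LoRA adapters.
        unet_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=UNET_TARGET_MODULES,
            lora_dropout=lora_dropout,
            bias=lora_bias,
        )
        unet = LoraModel(unet_config, unet)

        # Same for the text encoder, with its own projection names.
        text_encoder_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=TEXT_ENCODER_TARGET_MODULES,
            lora_dropout=lora_dropout,
            bias=lora_bias,
        )
        text_encoder = LoraModel(text_encoder_config, text_encoder)

        return unet, text_encoder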
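
The new one-line return in lora_prepare() leans on accelerator.prepare() returning its arguments as a tuple in the same order, so concatenating ({},) reproduces the old seven-element return value. Illustrated below with an abridged signature (the real function keeps its lora_* parameters):

    def lora_prepare(accelerator, text_encoder, unet, optimizer,
                     train_dataloader, val_dataloader, lr_scheduler, **kwargs):
        # accelerator.prepare() returns a tuple when given several objects, so
        # `+ ({},)` appends the empty state dict as the seventh element,
        # equivalent to the removed explicit
        # `return text_encoder, unet, optimizer, ..., lr_scheduler, {}`.
        return accelerator.prepare(
            text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
        ) + ({},)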