From 95adaea8b55d8e3755c035758bc649ae22548572 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 24 Mar 2023 10:53:16 +0100 Subject: Refactoring, fixed Lora training --- training/functional.py | 9 ++------ training/strategy/dreambooth.py | 17 ++++++-------- training/strategy/lora.py | 49 +++++++++-------------------------------- training/strategy/ti.py | 22 +++++++++--------- 4 files changed, 32 insertions(+), 65 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index a5b339d..ee73ab2 100644 --- a/training/functional.py +++ b/training/functional.py @@ -34,7 +34,6 @@ def const(result=None): @dataclass class TrainingCallbacks(): - on_prepare: Callable[[], None] = const() on_accum_model: Callable[[], torch.nn.Module] = const(None) on_log: Callable[[], dict[str, Any]] = const({}) on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) @@ -620,10 +619,8 @@ def train( kwargs.update(extra) vae.to(accelerator.device, dtype=dtype) - - for model in (unet, text_encoder, vae): - model.requires_grad_(False) - model.eval() + vae.requires_grad_(False) + vae.eval() callbacks = strategy.callbacks( accelerator=accelerator, @@ -636,8 +633,6 @@ def train( **kwargs, ) - callbacks.on_prepare() - loss_step_ = partial( loss_step, vae, diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 28fccff..9808027 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -74,6 +74,7 @@ def dreambooth_strategy_callbacks( power=ema_power, max_value=ema_max_decay, ) + ema_unet.to(accelerator.device) else: ema_unet = None @@ -86,14 +87,6 @@ def dreambooth_strategy_callbacks( def on_accum_model(): return unet - def on_prepare(): - unet.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(True) - text_encoder.text_model.final_layer_norm.requires_grad_(True) - - if ema_unet is not None: - ema_unet.to(accelerator.device) - @contextmanager def on_train(epoch: int): tokenizer.train() @@ -181,7 +174,6 @@ def dreambooth_strategy_callbacks( torch.cuda.empty_cache() return TrainingCallbacks( - on_prepare=on_prepare, on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, @@ -203,7 +195,12 @@ def dreambooth_prepare( lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs ): - return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + + text_encoder.text_model.embeddings.requires_grad_(False) + + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} dreambooth_strategy = TrainingStrategy( diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1c8fad6..3971eae 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -10,18 +10,12 @@ from torch.utils.data import DataLoader from accelerate import Accelerator from transformers import CLIPTextModel from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler -from peft import LoraConfig, LoraModel, get_peft_model_state_dict -from peft.tuners.lora import mark_only_lora_as_trainable +from peft import get_peft_model_state_dict from models.clip.tokenizer import MultiCLIPTokenizer from training.functional import TrainingStrategy, TrainingCallbacks, save_samples -# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py -UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] -TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] - - def lora_strategy_callbacks( accelerator: Accelerator, unet: UNet2DConditionModel, @@ -61,10 +55,6 @@ def lora_strategy_callbacks( image_size=sample_image_size, ) - def on_prepare(): - mark_only_lora_as_trainable(unet.model, unet.peft_config.bias) - mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias) - def on_accum_model(): return unet @@ -93,15 +83,15 @@ def lora_strategy_callbacks( text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) lora_config = {} - state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) - lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True) + state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) + lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) text_encoder_state_dict = get_peft_model_state_dict( - text_encoder, state_dict=accelerator.get_state_dict(text_encoder) + text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) ) text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} state_dict.update(text_encoder_state_dict) - lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) + lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) accelerator.print(state_dict) accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") @@ -111,11 +101,16 @@ def lora_strategy_callbacks( @torch.no_grad() def on_sample(step): + vae_dtype = vae.dtype + vae.to(dtype=text_encoder.dtype) + unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + vae.to(dtype=vae_dtype) + del unet_ del text_encoder_ @@ -123,7 +118,6 @@ def lora_strategy_callbacks( torch.cuda.empty_cache() return TrainingCallbacks( - on_prepare=on_prepare, on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, @@ -147,28 +141,7 @@ def lora_prepare( lora_bias: str = "none", **kwargs ): - unet_config = LoraConfig( - r=lora_rank, - lora_alpha=lora_alpha, - target_modules=UNET_TARGET_MODULES, - lora_dropout=lora_dropout, - bias=lora_bias, - ) - unet = LoraModel(unet_config, unet) - - text_encoder_config = LoraConfig( - r=lora_rank, - lora_alpha=lora_alpha, - target_modules=TEXT_ENCODER_TARGET_MODULES, - lora_dropout=lora_dropout, - bias=lora_bias, - ) - text_encoder = LoraModel(text_encoder_config, text_encoder) - - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) - - return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} + return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) lora_strategy = TrainingStrategy( diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 2038e34..10bc6d7 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -78,6 +78,7 @@ def textual_inversion_strategy_callbacks( power=ema_power, max_value=ema_max_decay, ) + ema_embeddings.to(accelerator.device) else: ema_embeddings = None @@ -92,15 +93,6 @@ def textual_inversion_strategy_callbacks( def on_accum_model(): return text_encoder.text_model.embeddings.temp_token_embedding - def on_prepare(): - text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) - - if ema_embeddings is not None: - ema_embeddings.to(accelerator.device) - - if gradient_checkpointing: - unet.train() - @contextmanager def on_train(epoch: int): tokenizer.train() @@ -177,7 +169,6 @@ def textual_inversion_strategy_callbacks( torch.cuda.empty_cache() return TrainingCallbacks( - on_prepare=on_prepare, on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, @@ -197,6 +188,7 @@ def textual_inversion_prepare( train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + gradient_checkpointing: bool = False, **kwargs ): weight_dtype = torch.float32 @@ -207,7 +199,17 @@ def textual_inversion_prepare( text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) + unet.to(accelerator.device, dtype=weight_dtype) + unet.requires_grad_(False) + unet.eval() + if gradient_checkpointing: + unet.train() + + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.eval() + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} -- cgit v1.2.3-54-g00ecf