From f5e0e98f6df9260a93fb650a0b97c85eb87b0fd3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 21 Mar 2023 13:46:36 +0100
Subject: Fixed SNR weighting, re-enabled xformers

---
 environment.yaml          |  2 ++
 train_lora.py             | 36 ++++++------------------
 train_ti.py               |  4 +--
 training/functional.py    | 35 +++++++++++++++++-------
 training/strategy/lora.py | 70 +++++++++++++++++++++++++++++++++++++++--------
 5 files changed, 97 insertions(+), 50 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index 9355f37..db43bd5 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -17,9 +17,11 @@ dependencies:
   - -e git+https://github.com/huggingface/diffusers#egg=diffusers
   - accelerate==0.17.1
   - bitsandbytes==0.37.1
+  - peft==0.2.0
   - python-slugify>=6.1.2
   - safetensors==0.3.0
   - setuptools==65.6.3
   - test-tube>=0.7.5
   - transformers==4.27.1
   - triton==2.0.0
+  - xformers==0.0.17.dev480
diff --git a/train_lora.py b/train_lora.py
index e65e7be..2a798f3 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -12,8 +12,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -426,34 +424,16 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.lora_rank
-        )
-
-    unet.set_attn_processor(lora_attn_procs)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -505,7 +485,6 @@ def main():
         unet=unet,
         text_encoder=text_encoder,
         vae=vae,
-        lora_layers=lora_layers,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
@@ -540,7 +519,10 @@ def main():
     datamodule.setup()
 
     optimizer = create_optimizer(
-        lora_layers.parameters(),
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
diff --git a/train_ti.py b/train_ti.py
index fd23517..2e92ae4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -547,8 +547,8 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    # vae.set_use_memory_efficient_attention_xformers(True)
-    # unet.enable_xformers_memory_efficient_attention()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
diff --git a/training/functional.py b/training/functional.py
index 8dc2b9f..43ee356 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -251,6 +251,25 @@ def add_placeholder_tokens(
     return placeholder_token_ids, initializer_token_ids
 
 
+def snr_weight(noisy_latents, latents, gamma):
+    if gamma:
+        sigma = torch.sub(noisy_latents, latents)
+        zeros = torch.zeros_like(sigma)
+        alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        snr = torch.div(alpha_mean_sq, sigma_mean_sq)
+        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
+        return snr_weight
+
+    return torch.tensor(
+        [1],
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: SchedulerMixin,
@@ -308,21 +327,13 @@ def loss_step(
     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
     # Get the target for loss depending on the prediction type
-    alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
-    snr = alpha_t / (1 - alpha_t)
-    min_snr = snr.clamp(max=min_snr_gamma)
-
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
-        loss_weight = min_snr / snr
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        loss_weight = min_snr / (snr + 1)
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    loss_weight = loss_weight[..., None, None, None]
-
     if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -339,7 +350,11 @@ def loss_step(
     else:
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
-    loss = (loss_weight * loss).mean([1, 2, 3]).mean()
+    loss = loss.mean([1, 2, 3])
+
+    loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma)
+    loss = (loss_weight * loss).mean()
+
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
@@ -412,7 +427,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0 and epoch != 0:
+                if epoch % sample_frequency == 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cab5e4c..aa75bec 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -2,6 +2,7 @@ from typing import Optional
 from functools import partial
 from contextlib import contextmanager
 from pathlib import Path
+import itertools
 
 import torch
 from torch.utils.data import DataLoader
@@ -9,12 +10,18 @@ from torch.utils.data import DataLoader
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
-from diffusers.loaders import AttnProcsLayers
+from peft import LoraConfig, LoraModel, get_peft_model_state_dict
+from peft.tuners.lora import mark_only_lora_as_trainable
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
+# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
+UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+
 def lora_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
@@ -27,7 +34,6 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    lora_layers: AttnProcsLayers,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -57,7 +63,8 @@
     )
 
     def on_prepare():
-        lora_layers.requires_grad_(True)
+        mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
+        mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
 
     def on_accum_model():
         return unet
@@ -73,24 +80,44 @@
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(
+            itertools.chain(unet.parameters(), text_encoder.parameters()),
+            max_grad_norm
+        )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(
-            checkpoint_output_dir / f"{step}_{postfix}",
-            safe_serialization=True
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
+
+        lora_config = {}
+        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
+        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+
+        text_encoder_state_dict = get_peft_model_state_dict(
+            text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
         )
+        text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+        state_dict.update(text_encoder_state_dict)
+        lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+
+        accelerator.print(state_dict)
+        accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
+
         del unet_
+        del text_encoder_
 
     @torch.no_grad()
     def on_sample(step):
         unet_ = accelerator.unwrap_model(unet, False)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
+
         save_samples_(step=step, unet=unet_)
+
         del unet_
+        del text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -114,13 +141,34 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    lora_layers: AttnProcsLayers,
+    lora_rank: int = 4,
+    lora_alpha: int = 32,
+    lora_dropout: float = 0,
+    lora_bias: str = "none",
     **kwargs
 ):
-    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    unet_config = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_alpha,
+        target_modules=UNET_TARGET_MODULES,
+        lora_dropout=lora_dropout,
+        bias=lora_bias,
+    )
+    unet = LoraModel(unet_config, unet)
+
+    text_encoder_config = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_alpha,
+        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        lora_dropout=lora_dropout,
+        bias=lora_bias,
+    )
+    text_encoder = LoraModel(text_encoder_config, text_encoder)
+
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 lora_strategy = TrainingStrategy(
-- 
cgit v1.2.3-54-g00ecf
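
Usage note (not part of the patch): the reworked on_checkpoint callback writes the UNet and text encoder LoRA weights into a single {step}_{postfix}.pt file, with the text encoder entries stored under a text_encoder_ prefix. The sketch below shows one way such a file might be loaded back. It is an assumption, not code from this commit: it presumes both models are already wrapped in LoraModel as done in lora_prepare, that the installed peft release exposes set_peft_model_state_dict next to get_peft_model_state_dict (older releases keep it in peft.utils), and the helper name load_lora_checkpoint is made up for illustration.

import torch
# Adjust this import if your peft version keeps the helper in peft.utils.
from peft import set_peft_model_state_dict


def load_lora_checkpoint(unet, text_encoder, checkpoint_path):
    # Hypothetical loader for the combined checkpoint written by on_checkpoint().
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Text encoder weights were saved under a "text_encoder_" prefix;
    # everything else belongs to the UNet.
    unet_state_dict = {
        k: v for k, v in state_dict.items() if not k.startswith("text_encoder_")
    }
    text_encoder_state_dict = {
        k[len("text_encoder_"):]: v
        for k, v in state_dict.items()
        if k.startswith("text_encoder_")
    }

    set_peft_model_state_dict(unet, unet_state_dict)
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict)
    return unet, text_encoder

Splitting on the text_encoder_ prefix mirrors the key layout produced by the save path, so the two models' weights can be told apart without any extra metadata in the checkpoint file.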