From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 7 Feb 2023 20:44:43 +0100 Subject: Add Lora --- training/strategy/dreambooth.py | 35 +++++++--- training/strategy/lora.py | 147 ++++++++++++++++++++++++++++++++++++++++ training/strategy/ti.py | 38 ++++++++--- 3 files changed, 203 insertions(+), 17 deletions(-) create mode 100644 training/strategy/lora.py (limited to 'training/strategy') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e88bf90..b4c77f3 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -61,14 +61,11 @@ def dreambooth_strategy_callbacks( save_samples_ = partial( save_samples, accelerator=accelerator, - unet=unet, - text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, sample_scheduler=sample_scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, - dtype=weight_dtype, output_dir=sample_output_dir, seed=seed, batch_size=sample_batch_size, @@ -94,7 +91,7 @@ def dreambooth_strategy_callbacks( else: return nullcontext() - def on_model(): + def on_accum_model(): return unet def on_prepare(): @@ -172,11 +169,29 @@ def dreambooth_strategy_callbacks( @torch.no_grad() def on_sample(step): with ema_context(): - save_samples_(step=step) + unet_ = accelerator.unwrap_model(unet) + text_encoder_ = accelerator.unwrap_model(text_encoder) + + orig_unet_dtype = unet_.dtype + orig_text_encoder_dtype = text_encoder_.dtype + + unet_.to(dtype=weight_dtype) + text_encoder_.to(dtype=weight_dtype) + + save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + + unet_.to(dtype=orig_unet_dtype) + text_encoder_.to(dtype=orig_text_encoder_dtype) + + del unet_ + del text_encoder_ + + if torch.cuda.is_available(): + torch.cuda.empty_cache() return TrainingCallbacks( on_prepare=on_prepare, - on_model=on_model, + on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, @@ -191,9 +206,13 @@ def dreambooth_prepare( accelerator: Accelerator, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, - *args + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + **kwargs ): - return accelerator.prepare(text_encoder, unet, *args) + return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) dreambooth_strategy = TrainingStrategy( diff --git a/training/strategy/lora.py b/training/strategy/lora.py new file mode 100644 index 0000000..88d1824 --- /dev/null +++ b/training/strategy/lora.py @@ -0,0 +1,147 @@ +from contextlib import nullcontext +from typing import Optional +from functools import partial +from contextlib import contextmanager, nullcontext +from pathlib import Path + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from accelerate import Accelerator +from transformers import CLIPTextModel +from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler +from diffusers.loaders import AttnProcsLayers + +from slugify import slugify + +from models.clip.tokenizer import MultiCLIPTokenizer +from training.util import EMAModel +from training.functional import TrainingStrategy, TrainingCallbacks, save_samples + + +def lora_strategy_callbacks( + accelerator: Accelerator, + unet: UNet2DConditionModel, + text_encoder: CLIPTextModel, + tokenizer: MultiCLIPTokenizer, + vae: AutoencoderKL, + sample_scheduler: DPMSolverMultistepScheduler, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + sample_output_dir: Path, + checkpoint_output_dir: Path, + seed: int, + lora_layers: AttnProcsLayers, + max_grad_norm: float = 1.0, + sample_batch_size: int = 1, + sample_num_batches: int = 1, + sample_num_steps: int = 20, + sample_guidance_scale: float = 7.5, + sample_image_size: Optional[int] = None, +): + sample_output_dir.mkdir(parents=True, exist_ok=True) + checkpoint_output_dir.mkdir(parents=True, exist_ok=True) + + weight_dtype = torch.float32 + if accelerator.state.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.state.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_samples_ = partial( + save_samples, + accelerator=accelerator, + unet=unet, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + sample_scheduler=sample_scheduler, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + output_dir=sample_output_dir, + seed=seed, + batch_size=sample_batch_size, + num_batches=sample_num_batches, + num_steps=sample_num_steps, + guidance_scale=sample_guidance_scale, + image_size=sample_image_size, + ) + + def on_prepare(): + lora_layers.requires_grad_(True) + + def on_accum_model(): + return unet + + @contextmanager + def on_train(epoch: int): + tokenizer.train() + yield + + @contextmanager + def on_eval(): + tokenizer.eval() + yield + + def on_before_optimize(lr: float, epoch: int): + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm) + + @torch.no_grad() + def on_checkpoint(step, postfix): + print(f"Saving checkpoint for step {step}...") + orig_unet_dtype = unet.dtype + unet.to(dtype=torch.float32) + unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) + unet.to(dtype=orig_unet_dtype) + + @torch.no_grad() + def on_sample(step): + orig_unet_dtype = unet.dtype + unet.to(dtype=weight_dtype) + save_samples_(step=step) + unet.to(dtype=orig_unet_dtype) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return TrainingCallbacks( + on_prepare=on_prepare, + on_accum_model=on_accum_model, + on_train=on_train, + on_eval=on_eval, + on_before_optimize=on_before_optimize, + on_checkpoint=on_checkpoint, + on_sample=on_sample, + ) + + +def lora_prepare( + accelerator: Accelerator, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + lora_layers: AttnProcsLayers, + **kwargs +): + weight_dtype = torch.float32 + if accelerator.state.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.state.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers} + + +lora_strategy = TrainingStrategy( + callbacks=lora_strategy_callbacks, + prepare=lora_prepare, +) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 14bdafd..d306f18 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -59,14 +59,11 @@ def textual_inversion_strategy_callbacks( save_samples_ = partial( save_samples, accelerator=accelerator, - unet=unet, - text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, sample_scheduler=sample_scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, - dtype=weight_dtype, output_dir=sample_output_dir, seed=seed, batch_size=sample_batch_size, @@ -94,7 +91,7 @@ def textual_inversion_strategy_callbacks( else: return nullcontext() - def on_model(): + def on_accum_model(): return text_encoder.text_model.embeddings.temp_token_embedding def on_prepare(): @@ -149,11 +146,29 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_sample(step): with ema_context(): - save_samples_(step=step) + unet_ = accelerator.unwrap_model(unet) + text_encoder_ = accelerator.unwrap_model(text_encoder) + + orig_unet_dtype = unet_.dtype + orig_text_encoder_dtype = text_encoder_.dtype + + unet_.to(dtype=weight_dtype) + text_encoder_.to(dtype=weight_dtype) + + save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + + unet_.to(dtype=orig_unet_dtype) + text_encoder_.to(dtype=orig_text_encoder_dtype) + + del unet_ + del text_encoder_ + + if torch.cuda.is_available(): + torch.cuda.empty_cache() return TrainingCallbacks( on_prepare=on_prepare, - on_model=on_model, + on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, @@ -168,7 +183,11 @@ def textual_inversion_prepare( accelerator: Accelerator, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, - *args + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + **kwargs ): weight_dtype = torch.float32 if accelerator.state.mixed_precision == "fp16": @@ -176,9 +195,10 @@ def textual_inversion_prepare( elif accelerator.state.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - prepped = accelerator.prepare(text_encoder, *args) + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) unet.to(accelerator.device, dtype=weight_dtype) - return (prepped[0], unet) + prepped[1:] + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} textual_inversion_strategy = TrainingStrategy( -- cgit v1.2.3-70-g09d2