from typing import Optional from functools import partial from contextlib import contextmanager from pathlib import Path import itertools import torch from torch.utils.data import DataLoader from accelerate import Accelerator from transformers import CLIPTextModel from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler from peft import LoraConfig, LoraModel, get_peft_model_state_dict from peft.tuners.lora import mark_only_lora_as_trainable from models.clip.tokenizer import MultiCLIPTokenizer from training.functional import TrainingStrategy, TrainingCallbacks, save_samples # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] def lora_strategy_callbacks( accelerator: Accelerator, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, tokenizer: MultiCLIPTokenizer, vae: AutoencoderKL, sample_scheduler: DPMSolverMultistepScheduler, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], sample_output_dir: Path, checkpoint_output_dir: Path, seed: int, max_grad_norm: float = 1.0, sample_batch_size: int = 1, sample_num_batches: int = 1, sample_num_steps: int = 20, sample_guidance_scale: float = 7.5, sample_image_size: Optional[int] = None, ): sample_output_dir.mkdir(parents=True, exist_ok=True) checkpoint_output_dir.mkdir(parents=True, exist_ok=True) save_samples_ = partial( save_samples, accelerator=accelerator, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, sample_scheduler=sample_scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, output_dir=sample_output_dir, seed=seed, batch_size=sample_batch_size, num_batches=sample_num_batches, num_steps=sample_num_steps, guidance_scale=sample_guidance_scale, image_size=sample_image_size, ) def on_prepare(): mark_only_lora_as_trainable(unet.model, unet.peft_config.bias) mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias) def on_accum_model(): return unet @contextmanager def on_train(epoch: int): tokenizer.train() yield @contextmanager def on_eval(): tokenizer.eval() yield def on_before_optimize(lr: float, epoch: int): accelerator.clip_grad_norm_( itertools.chain(unet.parameters(), text_encoder.parameters()), max_grad_norm ) @torch.no_grad() def on_checkpoint(step, postfix): print(f"Saving checkpoint for step {step}...") unet_ = accelerator.unwrap_model(unet, False) text_encoder_ = accelerator.unwrap_model(text_encoder, False) lora_config = {} state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True) text_encoder_state_dict = get_peft_model_state_dict( text_encoder, state_dict=accelerator.get_state_dict(text_encoder) ) text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} state_dict.update(text_encoder_state_dict) lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) accelerator.print(state_dict) accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") del unet_ del text_encoder_ @torch.no_grad() def on_sample(step): unet_ = accelerator.unwrap_model(unet, False) text_encoder_ = accelerator.unwrap_model(text_encoder, False) save_samples_(step=step, unet=unet_) del unet_ del text_encoder_ if torch.cuda.is_available(): torch.cuda.empty_cache() return TrainingCallbacks( on_prepare=on_prepare, on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, on_checkpoint=on_checkpoint, on_sample=on_sample, ) def lora_prepare( accelerator: Accelerator, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, optimizer: torch.optim.Optimizer, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, lora_rank: int = 4, lora_alpha: int = 32, lora_dropout: float = 0, lora_bias: str = "none", **kwargs ): unet_config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, target_modules=UNET_TARGET_MODULES, lora_dropout=lora_dropout, bias=lora_bias, ) unet = LoraModel(unet_config, unet) text_encoder_config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, target_modules=TEXT_ENCODER_TARGET_MODULES, lora_dropout=lora_dropout, bias=lora_bias, ) text_encoder = LoraModel(text_encoder_config, text_encoder) text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} lora_strategy = TrainingStrategy( callbacks=lora_strategy_callbacks, prepare=lora_prepare, )