from typing import Optional
from functools import partial
from contextlib import contextmanager
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from diffusers.loaders import AttnProcsLayers

from models.clip.tokenizer import MultiCLIPTokenizer
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def lora_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    lora_layers: AttnProcsLayers,
    max_grad_norm: float = 1.0,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
) -> TrainingCallbacks:
    """Build the set of training callbacks used by the LoRA strategy."""
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # Pre-bind everything that does not change between sampling runs.
    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    def on_prepare():
        # Only the LoRA attention layers are trained; everything else stays frozen.
        lora_layers.requires_grad_(True)

    def on_accum_model():
        # Model whose gradient accumulation context the trainer should use.
        return unet

    @contextmanager
    def on_train(epoch: int):
        # Switch the multi-vector tokenizer into training mode for the epoch.
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        # Switch the multi-vector tokenizer into eval mode for validation/sampling.
        tokenizer.eval()
        yield

    def on_before_optimize(lr: float, epoch: int):
        # Clip only the trainable LoRA parameters before the optimizer step.
        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")

        unet_ = accelerator.unwrap_model(unet, False)
        unet_.save_attn_procs(
            checkpoint_output_dir / f"{step}_{postfix}",
            safe_serialization=True,
        )
        del unet_

    @torch.no_grad()
    def on_sample(step):
        unet_ = accelerator.unwrap_model(unet, False)
        save_samples_(step=step, unet=unet_)
        del unet_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_prepare=on_prepare,
        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def lora_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    lora_layers: AttnProcsLayers,
    **kwargs
):
    # Only the LoRA layers (plus optimizer, dataloaders and LR scheduler) go
    # through accelerator.prepare; the frozen UNet and text encoder are
    # returned unchanged.
    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    return (
        text_encoder,
        unet,
        optimizer,
        train_dataloader,
        val_dataloader,
        lr_scheduler,
        {"lora_layers": lora_layers},
    )


lora_strategy = TrainingStrategy(
    callbacks=lora_strategy_callbacks,
    prepare=lora_prepare,
)
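

# ----------------------------------------------------------------------------
# Example (illustrative only): one way to construct the `lora_layers` argument
# expected by `lora_strategy_callbacks` / `lora_prepare`. This is a minimal
# sketch assuming the older diffusers API that also provides `AttnProcsLayers`
# and `save_attn_procs`; the helper name and the default rank are hypothetical
# and not part of this repository.
# ----------------------------------------------------------------------------
def _example_build_lora_layers(unet: UNet2DConditionModel, rank: int = 4) -> AttnProcsLayers:
    from diffusers.models.attention_processor import LoRAAttnProcessor

    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention processors ("attn1") have no cross-attention conditioning.
        cross_attention_dim = (
            None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        )
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            rank=rank,
        )

    # Attach the LoRA processors to the UNet, then wrap them so the optimizer
    # and `accelerator.prepare` see a single trainable module.
    unet.set_attn_processor(lora_attn_procs)
    return AttnProcsLayers(unet.attn_processors)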