from typing import Optional
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from diffusers.loaders import AttnProcsLayers

from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def lora_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    lora_layers: AttnProcsLayers,
    max_grad_norm: float = 1.0,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # Match the sampling dtype to the mixed-precision mode used for training.
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    def on_prepare():
        # Only the LoRA attention processors receive gradients; the base UNet
        # and text encoder stay frozen.
        lora_layers.requires_grad_(True)

    def on_accum_model():
        return unet

    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()
        yield

    def on_before_optimize(lr: float, epoch: int):
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")

        # Save the attention processors in full precision, then restore the
        # UNet's original dtype.
        orig_unet_dtype = unet.dtype
        unet.to(dtype=torch.float32)
        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
        unet.to(dtype=orig_unet_dtype)

    @torch.no_grad()
    def on_sample(step):
        # Sample in the training weight dtype, then restore the original dtype.
        orig_unet_dtype = unet.dtype
        unet.to(dtype=weight_dtype)
        save_samples_(step=step)
        unet.to(dtype=orig_unet_dtype)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_prepare=on_prepare,
        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def lora_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    lora_layers: AttnProcsLayers,
    **kwargs
):
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Only the trainable LoRA layers (plus optimizer, dataloaders, and
    # scheduler) are wrapped by Accelerate; the frozen UNet and text encoder
    # are simply moved to the target device and dtype.
    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)

    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}


lora_strategy = TrainingStrategy(
    callbacks=lora_strategy_callbacks,
    prepare=lora_prepare,
)
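

# Usage note (hypothetical sketch, not part of the original module): the exact
# entry point in training.functional is an assumption here; this only
# illustrates that `lora_strategy` bundles the callback factory and the
# prepare step into a single TrainingStrategy object that a generic training
# loop can consume, e.g.:
#
#     from training.functional import train   # assumed entry point
#     from training.strategy.lora import lora_strategy  # assumed module path
#
#     train(..., strategy=lora_strategy)
#
# where the loop is expected to call `strategy.prepare(...)` once before
# training and `strategy.callbacks(...)` to obtain the TrainingCallbacks it
# invokes at each stage.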