from typing import Optional
from functools import partial
from contextlib import contextmanager
from pathlib import Path
import itertools
import json

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from peft import get_peft_model_state_dict
from safetensors.torch import save_file

from models.clip.tokenizer import MultiCLIPTokenizer
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def lora_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    max_grad_norm: float = 1.0,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    def on_accum_model():
        # Model whose gradients Accelerate synchronizes during accumulation.
        return unet

    @contextmanager
    def on_train(epoch: int):
        # Put the tokenizer and text encoder into training mode for the epoch.
        tokenizer.train()
        text_encoder.train()
        yield

    @contextmanager
    def on_eval():
        # Put the tokenizer into evaluation mode for validation/sampling.
        tokenizer.eval()
        yield

    def on_before_optimize(lr: float, epoch: int):
        # Clip the gradients of both trained models before the optimizer step.
        accelerator.clip_grad_norm_(
            itertools.chain(unet.parameters(), text_encoder.parameters()),
            max_grad_norm
        )

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        # Only write a checkpoint at the end of training.
        if postfix != "end":
            return

        print(f"Saving checkpoint for step {step}...")

        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

        # Merge the LoRA weights of both models into a single state dict;
        # text-encoder entries are prefixed to keep the two sets distinguishable.
        lora_config = {}
        state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
        lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)

        text_encoder_state_dict = get_peft_model_state_dict(
            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
        )
        text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
        state_dict.update(text_encoder_state_dict)
        lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)

        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
        with open(checkpoint_output_dir / "lora_config.json", "w") as f:
            json.dump(lora_config, f)

        del unet_
        del text_encoder_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def on_sample(step):
        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)

        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)

        del unet_
        del text_encoder_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def lora_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    lora_rank: int = 4,
    lora_alpha: int = 32,
    lora_dropout: float = 0,
    lora_bias: str = "none",
    **kwargs
):
    # The LoRA hyperparameters are accepted for interface parity with other
    # strategies but are not consumed here; an empty dict is appended to the
    # tuple returned by accelerator.prepare.
    return accelerator.prepare(
        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
    ) + ({},)


lora_strategy = TrainingStrategy(
    callbacks=lora_strategy_callbacks,
    prepare=lora_prepare,
)
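
# Loading sketch (illustrative only; the directory and the step number in the
# file name are assumptions for this example): on_checkpoint() writes a single
# flat safetensors file in which text-encoder LoRA entries carry a
# "text_encoder_" prefix, alongside a lora_config.json holding both PEFT
# configs. A consumer could split the two state dicts back apart roughly like
# this:
#
#   import json
#   from pathlib import Path
#   from safetensors.torch import load_file
#
#   checkpoint_dir = Path("output/checkpoints")                     # assumed location
#   state_dict = load_file(checkpoint_dir / "500_end.safetensors")  # example step/postfix
#   text_encoder_sd = {
#       k.removeprefix("text_encoder_"): v
#       for k, v in state_dict.items()
#       if k.startswith("text_encoder_")
#   }
#   unet_sd = {k: v for k, v in state_dict.items() if not k.startswith("text_encoder_")}
#   with open(checkpoint_dir / "lora_config.json") as f:
#       lora_config = json.load(f)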