"""LoRA training strategy: callbacks and Accelerate preparation for training
LoRA adapters on the UNet and text encoder, optionally together with
textual-inversion placeholder embeddings."""

from typing import Optional
from functools import partial
from contextlib import contextmanager
from pathlib import Path
import itertools
import json

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from peft import get_peft_model_state_dict
from safetensors.torch import save_file
from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def lora_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    placeholder_tokens: list[str],
    placeholder_token_ids: list[list[int]],
    pti_mode: bool = False,
    train_text_encoder_cycles: int = 99999,
    use_emb_decay: bool = False,
    emb_decay_target: float = 0.4,
    emb_decay: float = 1e-2,
    max_grad_norm: float = 1.0,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    @contextmanager
    def on_train(cycle: int):
        # Training mode: the text encoder (and tokenizer dropout) only
        # participates for the first `train_text_encoder_cycles` cycles.
        unet.train()

        if cycle < train_text_encoder_cycles:
            text_encoder.train()
            tokenizer.train()

        yield

    @contextmanager
    def on_eval():
        # Evaluation mode for validation and sampling.
        unet.eval()
        text_encoder.eval()
        tokenizer.eval()
        yield

    def on_before_optimize(cycle: int):
        if not pti_mode:
            # Clip gradients of the UNet and text encoder parameters; in PTI
            # mode only the embeddings are trained, so clipping is skipped.
            accelerator.clip_grad_norm_(
                itertools.chain(
                    unet.parameters(),
                    text_encoder.text_model.encoder.parameters(),
                    text_encoder.text_model.final_layer_norm.parameters(),
                ),
                max_grad_norm,
            )

        if len(placeholder_tokens) != 0 and use_emb_decay:
            # Collect the embedding parameters that received gradients so
            # on_after_optimize can apply the norm decay to them.
            params = [
                p
                for p in text_encoder.text_model.embeddings.parameters()
                if p.grad is not None
            ]
            return torch.stack(params) if len(params) != 0 else None

    @torch.no_grad()
    def on_after_optimize(w, lrs: dict[str, float]):
        if w is not None and "emb" in lrs:
            # Nudge the norm of the trained embeddings toward
            # `emb_decay_target`, scaled by the embedding learning rate.
            lr = lrs["emb"]
            lambda_ = emb_decay * lr

            if lambda_ != 0:
                norm = w[:, :].norm(dim=-1, keepdim=True)
                w[:].add_(
                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
                )

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        # Only the final checkpoint is written to disk.
        if postfix != "end":
            return

        print(f"Saving checkpoint for step {step}...")

        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

        # for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
        #     text_encoder_.text_model.embeddings.save_embed(
        #         ids,
        #         checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
        #     )

        if not pti_mode:
            lora_config = {}

            # LoRA weights for the UNet.
            state_dict = get_peft_model_state_dict(
                unet_, state_dict=accelerator.get_state_dict(unet_)
            )
            lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)

            # LoRA weights for the text encoder, namespaced with a
            # "text_encoder_" prefix so both adapters share one file.
            text_encoder_state_dict = get_peft_model_state_dict(
                text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
            )
            text_encoder_state_dict = {
                f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
            }
            state_dict.update(text_encoder_state_dict)
            lora_config[
                "text_encoder_peft_config"
            ] = text_encoder_.get_peft_config_as_dict(inference=True)

            # Trained placeholder (textual-inversion) embeddings, if any.
            if len(placeholder_tokens) != 0:
                ti_state_dict = {
                    f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
                    for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
                }
                state_dict.update(ti_state_dict)

            save_file(
                state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
            )
            with open(checkpoint_output_dir / "lora_config.json", "w") as f:
                json.dump(lora_config, f)

        del unet_, text_encoder_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def on_sample(cycle, step):
        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)

        save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_)

        del unet_, text_encoder_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_after_optimize=on_after_optimize,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def lora_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    **kwargs,
):
    # Let Accelerate wrap the models, optimizer, dataloaders and LR scheduler
    # for the current (possibly distributed / mixed-precision) setup.
    (
        text_encoder,
        unet,
        optimizer,
        train_dataloader,
        val_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler


lora_strategy = TrainingStrategy(
    callbacks=lora_strategy_callbacks,
    prepare=lora_prepare,
)