"""Textual-inversion training strategy: callbacks plus an Accelerate prepare step
that train only the text encoder's token embeddings while the UNet stays frozen."""

from typing import Optional
from types import MethodType
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def textual_inversion_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    placeholder_tokens: list[str],
    placeholder_token_ids: list[list[int]],
    use_emb_decay: bool = False,
    emb_decay_target: float = 0.4,
    emb_decay: float = 1e-2,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: int = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # Resolve the dtype used for sampling from the Accelerate mixed-precision setting.
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    # Optionally track an exponential moving average of the token embeddings.
    if use_ema:
        ema_embeddings = EMAModel(
            text_encoder.text_model.embeddings.token_embedding.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
        ema_embeddings.to(accelerator.device)
    else:
        ema_embeddings = None

    def ema_context():
        # Temporarily swap in the EMA weights (no-op when EMA is disabled).
        if ema_embeddings is not None:
            return ema_embeddings.apply_temporary(
                text_encoder.text_model.embeddings.token_embedding.parameters()
            )
        else:
            return nullcontext()

    @contextmanager
    def on_train(cycle: int):
        text_encoder.train()
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        text_encoder.eval()
        tokenizer.eval()

        with ema_context():
            yield

    @torch.no_grad()
    def on_before_optimize(cycle: int):
        # Collect the embedding parameters that received gradients so that
        # on_after_optimize can apply weight decay toward the target norm.
        if use_emb_decay:
            params = [
                p
                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
                if p.grad is not None
            ]
            return torch.stack(params) if len(params) != 0 else None

    @torch.no_grad()
    def on_after_optimize(w, lrs: dict[str, float]):
        if ema_embeddings is not None:
            ema_embeddings.step(
                text_encoder.text_model.embeddings.token_embedding.parameters()
            )

        # Nudge the norm of the collected embedding vectors toward
        # emb_decay_target, scaled by the embedding learning rate.
        if use_emb_decay and w is not None:
            lr = lrs["emb"] if "emb" in lrs else lrs["0"]
            lambda_ = emb_decay * lr

            if lambda_ != 0:
                norm = w[:, :].norm(dim=-1, keepdim=True)
                w[:].add_(
                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
                )

    def on_log():
        if ema_embeddings is not None:
            return {"ema_decay": ema_embeddings.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")

        if postfix == "end":
            # Drop the mixed-precision forward wrapper for the final checkpoint.
            text_encoder_ = accelerator.unwrap_model(
                text_encoder, keep_fp32_wrapper=False
            )
            text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)

        # Save one embedding file per placeholder token, using the EMA weights
        # if they are enabled.
        with ema_context():
            for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
                )

    @torch.no_grad()
    def on_sample(cycle, step):
        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)

        # Cast to the sampling dtype, then restore the original dtypes afterwards.
        orig_unet_dtype = unet_.dtype
        orig_text_encoder_dtype = text_encoder_.dtype

        unet_.to(dtype=weight_dtype)
        text_encoder_.to(dtype=weight_dtype)

        save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_)

        unet_.to(dtype=orig_unet_dtype)
        text_encoder_.to(dtype=orig_text_encoder_dtype)

        del unet_, text_encoder_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_after_optimize=on_after_optimize,
        on_log=on_log,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def textual_inversion_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    gradient_checkpointing: bool = False,
    **kwargs,
):
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Only the text encoder is trained; the UNet is not passed to accelerator.prepare.
    (
        text_encoder,
        optimizer,
        train_dataloader,
        val_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    unet.to(accelerator.device, dtype=weight_dtype)
    unet.requires_grad_(False)
    unet.eval()

    # Keep the UNet in train mode when gradient checkpointing is enabled so that
    # checkpointing works; its weights stay frozen either way.
    if gradient_checkpointing:
        unet.train()

    # Freeze everything in the text encoder except the token embeddings.
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler


textual_inversion_strategy = TrainingStrategy(
    callbacks=textual_inversion_strategy_callbacks,
    prepare=textual_inversion_prepare,
)