from typing import Optional
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def _mixed_precision_dtype(accelerator: Accelerator) -> torch.dtype:
    """Map the accelerator's mixed-precision setting to a torch dtype.

    Returns ``torch.float16`` for ``"fp16"``, ``torch.bfloat16`` for ``"bf16"``
    and ``torch.float32`` for any other value.
    """
    mixed_precision = accelerator.state.mixed_precision
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32


def textual_inversion_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    placeholder_tokens: list[str],
    placeholder_token_ids: list[list[int]],
    gradient_checkpointing: bool = False,
    use_emb_decay: bool = False,
    emb_decay_target: float = 0.4,
    emb_decay: float = 1e-2,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: int = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    """Build the ``TrainingCallbacks`` used by the textual-inversion strategy.

    Creates both output directories, optionally sets up an EMA copy of the
    temporary token embedding parameters, and returns a ``TrainingCallbacks``
    bundle whose hooks close over the arguments given here.

    Args:
        accelerator: Supplies the device, the mixed-precision setting and
            ``unwrap_model``.
        unet: UNet passed to ``save_samples`` when sampling.
        text_encoder: Text encoder whose
            ``text_model.embeddings.temp_token_embedding`` is the trained module.
        tokenizer: Switched between train and eval mode by the hooks.
        vae, sample_scheduler, train_dataloader, val_dataloader: Forwarded
            unchanged to ``save_samples``.
        sample_output_dir: Directory passed to ``save_samples`` as ``output_dir``.
        checkpoint_output_dir: Directory the embedding checkpoints are written to.
        seed: Forwarded to ``save_samples``.
        placeholder_tokens: Token strings; used (slugified) in checkpoint names.
        placeholder_token_ids: Per-token id lists, zipped with
            ``placeholder_tokens`` when saving.
        gradient_checkpointing: Accepted for signature parity; not used here.
        use_emb_decay: Enable the norm decay step after each optimizer step.
        emb_decay_target: Norm the decayed embedding rows are pulled towards.
        emb_decay: Decay strength; multiplied by the learning rate.
        use_ema: Track an EMA of the temporary token embedding parameters.
        ema_inv_gamma, ema_power, ema_max_decay: Forwarded to ``EMAModel`` as
            ``inv_gamma``, ``power`` and ``max_value``.
        sample_batch_size, sample_num_batches, sample_num_steps,
        sample_guidance_scale, sample_image_size: Forwarded to ``save_samples``.

    Returns:
        A ``TrainingCallbacks`` instance wired to the hooks defined below.
    """
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    weight_dtype = _mixed_precision_dtype(accelerator)

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    if use_ema:
        ema_embeddings = EMAModel(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
        ema_embeddings.to(accelerator.device)
    else:
        ema_embeddings = None

    def ema_context():
        """Return a context that temporarily applies the EMA weights, if any."""
        if ema_embeddings is not None:
            return ema_embeddings.apply_temporary(
                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            )
        else:
            return nullcontext()

    def on_accum_model():
        """Return the module whose gradients are accumulated."""
        return text_encoder.text_model.embeddings.temp_token_embedding

    @contextmanager
    def on_train(epoch: int):
        """Put the tokenizer into train mode for the duration of the context."""
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        """Put the tokenizer into eval mode and apply EMA weights, if enabled."""
        tokenizer.eval()

        with ema_context():
            yield

    @torch.no_grad()
    def on_before_optimize(lr: float, epoch: int):
        """Record which embedding rows received an all-zero gradient.

        Returns:
            A boolean tensor with one entry per embedding row (``True`` where
            the row's gradient is entirely zero) when ``use_emb_decay`` is set,
            otherwise ``None``. The result is handed to ``on_after_optimize``
            as ``zero_ids``.
        """
        if use_emb_decay:
            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
            return torch.all(w.grad == 0, dim=1)
        return None

    @torch.no_grad()
    def on_after_optimize(zero_ids, lr: float):
        """Advance the EMA and pull trained embedding norms towards the target.

        Args:
            zero_ids: Value returned by ``on_before_optimize``; rows it selects
                are excluded from the decay.
            lr: Current learning rate; scales the decay strength.
        """
        if ema_embeddings is not None:
            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())

        if use_emb_decay:
            lambda_ = emb_decay * lr

            if lambda_ != 0:
                w = text_encoder.text_model.embeddings.temp_token_embedding.weight

                # Build the mask on the weight's device so it can be indexed
                # with `zero_ids`, which is derived from `w.grad`.
                mask = torch.ones(w.shape[0], dtype=torch.bool, device=w.device)
                mask[zero_ids] = False

                rows = w[mask]
                norm = rows.norm(dim=-1, keepdim=True)

                # Boolean-mask indexing yields a copy, so an in-place op on
                # `w[mask]` would never reach `w`; assign the result back.
                w[mask] = rows + (rows / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)

    def on_log():
        """Return extra values to log: the EMA decay when EMA is enabled."""
        if ema_embeddings is not None:
            return {"ema_decay": ema_embeddings.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        """Save one embedding file per placeholder token (EMA weights if enabled)."""
        print(f"Saving checkpoint for step {step}...")

        with ema_context():
            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
                )

    @torch.no_grad()
    def on_sample(step):
        """Generate sample images with the models cast to the working dtype."""
        with ema_context():
            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)

            orig_unet_dtype = unet_.dtype
            orig_text_encoder_dtype = text_encoder_.dtype

            try:
                unet_.to(dtype=weight_dtype)
                text_encoder_.to(dtype=weight_dtype)

                save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
            finally:
                # Restore the training dtypes even if sampling fails, so a
                # caught sampling error cannot leave the models down-cast.
                unet_.to(dtype=orig_unet_dtype)
                text_encoder_.to(dtype=orig_text_encoder_dtype)

            del unet_
            del text_encoder_

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_after_optimize=on_after_optimize,
        on_log=on_log,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def textual_inversion_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    gradient_checkpointing: bool = False,
    **kwargs
):
    """Prepare models, optimizer and data for textual-inversion training.

    The text encoder, optimizer, dataloaders and LR scheduler go through
    ``accelerator.prepare``. The UNet is moved to the accelerator device in the
    mixed-precision dtype and frozen. The text encoder's transformer encoder
    and final layer norm are frozen as well.

    Args:
        accelerator: Accelerator used for preparation and device placement.
        text_encoder: Text encoder to prepare and partially freeze.
        unet: UNet to freeze and move to the device.
        optimizer, train_dataloader, val_dataloader, lr_scheduler: Passed
            through ``accelerator.prepare``.
        gradient_checkpointing: When true, the (frozen) UNet is put back into
            train mode after being set to eval mode.
        **kwargs: Ignored.

    Returns:
        ``(text_encoder, unet, optimizer, train_dataloader, val_dataloader,
        lr_scheduler, {})`` -- the trailing dict is always empty.
    """
    weight_dtype = _mixed_precision_dtype(accelerator)

    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)

    unet.to(accelerator.device, dtype=weight_dtype)
    unet.requires_grad_(False)
    unet.eval()

    if gradient_checkpointing:
        unet.train()

    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.eval()

    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


textual_inversion_strategy = TrainingStrategy(
    callbacks=textual_inversion_strategy_callbacks,
    prepare=textual_inversion_prepare,
)