from typing import Optional
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import save_samples


def textual_inversion_strategy(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    dtype: torch.dtype,
    output_dir: Path,
    seed: int,
    placeholder_tokens: list[str],
    placeholder_token_ids: list[list[int]],
    learning_rate: float,
    gradient_checkpointing: bool = False,
    use_emb_decay: bool = False,
    emb_decay_target: float = 0.4,
    emb_decay_factor: float = 1.0,
    emb_decay_start: float = 1e-4,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: float = 1.0,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    """Build the textual-inversion training callbacks and return them as a dict."""

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=dtype,
        output_dir=output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    if use_ema:
        ema_embeddings = EMAModel(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
    else:
        ema_embeddings = None

    def on_prepare():
        # Only the temporary token embeddings are trained.
        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        if use_ema:
            ema_embeddings.to(accelerator.device)

        if gradient_checkpointing:
            # Gradient checkpointing is only active while the UNet is in train mode.
            unet.train()

    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()

        ema_context = ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if use_ema else nullcontext()

        with ema_context:
            yield

    @torch.no_grad()
    def on_after_optimize(lr: float):
        if use_emb_decay:
            # Decay strength is scaled by how far the current lr sits along the
            # ramp from emb_decay_start to the base learning rate, clamped to [0, 1].
            text_encoder.text_model.embeddings.normalize(
                emb_decay_target,
                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))),
            )

        if use_ema:
            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())

    def on_log():
        if use_ema:
            return {"ema_decay": ema_embeddings.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        # Bind the unwrapped model to a new name; assigning to `text_encoder`
        # itself would shadow the closure variable and raise UnboundLocalError.
        text_encoder_ = accelerator.unwrap_model(text_encoder)

        ema_context = ema_embeddings.apply_temporary(
            text_encoder_.text_model.embeddings.temp_token_embedding.parameters()
        ) if ema_embeddings is not None else nullcontext()

        with ema_context:
            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                text_encoder_.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin"),
                )

    @torch.no_grad()
    def on_sample(step):
        ema_context = ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if ema_embeddings is not None else nullcontext()

        with ema_context:
            save_samples_(step=step)

    return {
        "on_prepare": on_prepare,
        "on_train": on_train,
        "on_eval": on_eval,
        "on_after_optimize": on_after_optimize,
        "on_log": on_log,
        "on_checkpoint": on_checkpoint,
        "on_sample": on_sample,
    }
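

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the strategy itself): how a
# training driver might consume the callback dict returned above. The actual
# loop presumably lives elsewhere (e.g. training.functional); the names
# `callbacks`, `num_epochs`, `optimizer`, `lr_scheduler`, `compute_loss`,
# `global_step`, and the "end" postfix below are assumptions for illustration.
#
#     callbacks = textual_inversion_strategy(...)  # arguments elided
#     callbacks["on_prepare"]()
#     for epoch in range(num_epochs):
#         with callbacks["on_train"](epoch):
#             for batch in train_dataloader:
#                 loss = compute_loss(batch)
#                 accelerator.backward(loss)
#                 optimizer.step()
#                 lr_scheduler.step()
#                 optimizer.zero_grad()
#                 callbacks["on_after_optimize"](lr_scheduler.get_last_lr()[0])
#                 accelerator.log(callbacks["on_log"]())
#         with callbacks["on_eval"]():
#             callbacks["on_sample"](global_step)
#             callbacks["on_checkpoint"](global_step, "end")
# ---------------------------------------------------------------------------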