from contextlib import contextmanager, nullcontext
from functools import partial
from pathlib import Path
from typing import Optional

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingCallbacks, save_samples


def textual_inversion_strategy(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    placeholder_tokens: list[str],
    placeholder_token_ids: list[list[int]],
    learning_rate: float,
    gradient_checkpointing: bool = False,
    use_emb_decay: bool = False,
    emb_decay_target: float = 0.4,
    emb_decay_factor: float = 1,
    emb_decay_start: float = 0,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: float = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    """Build the TrainingCallbacks for textual-inversion training.

    Only the placeholder token embeddings are trained; optionally an EMA copy of
    those embeddings is tracked and a norm decay is applied after each optimizer step.
    """
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # Match the sampling dtype to the accelerator's mixed-precision setting.
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Pre-bind everything save_samples needs so the sample callback only passes the step.
    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=weight_dtype,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    # Optionally track an exponential moving average of the learned placeholder embeddings.
    if use_ema:
        ema_embeddings = EMAModel(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
    else:
        ema_embeddings = None

    def ema_context():
        # Temporarily swap the EMA weights in for evaluation, checkpointing, and sampling.
        if use_ema:
            return ema_embeddings.apply_temporary(
                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            )
        else:
            return nullcontext()

    def on_model():
        return text_encoder

    def on_prepare():
        # Only the new token embeddings receive gradients; the rest of the model stays frozen.
        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        if use_ema:
            ema_embeddings.to(accelerator.device)

        if gradient_checkpointing:
            unet.train()

    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()

        with ema_context():
            yield

    @torch.no_grad()
    def on_after_optimize(lr: float):
        if use_emb_decay:
            # Decay strength scales with the current LR relative to the base learning
            # rate (offset by emb_decay_start) and is clamped to [0, 1].
            text_encoder.text_model.embeddings.normalize(
                emb_decay_target,
                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))),
            )

        if use_ema:
            ema_embeddings.step(
                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            )

    def on_log():
        if use_ema:
            return {"ema_decay": ema_embeddings.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")

        # Save each placeholder token's embedding to its own .bin file.
        with ema_context():
            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin"),
                )

    @torch.no_grad()
    def on_sample(step):
        with ema_context():
            save_samples_(step=step)

    return TrainingCallbacks(
        on_prepare=on_prepare,
        on_model=on_model,
        on_train=on_train,
        on_eval=on_eval,
        on_after_optimize=on_after_optimize,
        on_log=on_log,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )
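

# Usage sketch (illustrative only): the surrounding trainer is assumed to build the
# models, tokenizer, dataloaders, and Accelerator elsewhere and to pass the returned
# callbacks into the shared training loop in training.functional. Names such as
# `train_loop`, `output_dir`, and the hyperparameter values below are assumptions,
# not part of this module.
#
#     callbacks = textual_inversion_strategy(
#         accelerator=accelerator,
#         unet=unet,
#         text_encoder=text_encoder,
#         tokenizer=tokenizer,
#         vae=vae,
#         sample_scheduler=sample_scheduler,
#         train_dataloader=train_dataloader,
#         val_dataloader=val_dataloader,
#         sample_output_dir=output_dir / "samples",
#         checkpoint_output_dir=output_dir / "checkpoints",
#         seed=seed,
#         placeholder_tokens=["<token>"],
#         placeholder_token_ids=placeholder_token_ids,
#         learning_rate=1e-4,
#         use_ema=True,
#     )
#     train_loop(..., callbacks=callbacks)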