from typing import Optional
from functools import partial
from contextlib import contextmanager
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def textual_inversion_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    placeholder_tokens: list[str],
    placeholder_token_ids: list[list[int]],
    gradient_checkpointing: bool = False,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # Match the sampling dtype to the Accelerate mixed-precision setting.
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Pre-bind everything that stays constant across sampling calls.
    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    def on_accum_model():
        # The embedding overlay is the only trainable module in textual
        # inversion, so gradient accumulation is scoped to it.
        return text_encoder.text_model.embeddings.overlay

    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()
        yield

    @torch.no_grad()
    def on_checkpoint(step: int, postfix: str):
        print(f"Saving checkpoint for step {step}...")

        # Save each placeholder token's learned embedding vectors under a
        # slugified filename.
        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
            text_encoder.text_model.embeddings.save_embed(
                ids,
                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
            )

    @torch.no_grad()
    def on_sample(step: int):
        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)

        # Temporarily cast the unwrapped models to the sampling dtype, render
        # the samples, then restore the original dtypes and free cached VRAM.
        orig_unet_dtype = unet_.dtype
        orig_text_encoder_dtype = text_encoder_.dtype

        unet_.to(dtype=weight_dtype)
        text_encoder_.to(dtype=weight_dtype)

        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)

        unet_.to(dtype=orig_unet_dtype)
        text_encoder_.to(dtype=orig_text_encoder_dtype)

        del unet_
        del text_encoder_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def textual_inversion_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    gradient_checkpointing: bool = False,
    **kwargs
):
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Only the text encoder is trained; let Accelerate wrap it together with
    # the optimizer, dataloaders, and LR scheduler.
    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)

    # The UNet stays frozen and is only used for forward passes.
    unet.to(accelerator.device, dtype=weight_dtype)
    unet.requires_grad_(False)
    unet.eval()

    # Gradient checkpointing requires train mode, even on a frozen module.
    if gradient_checkpointing:
        unet.train()

    # Freeze everything in the text encoder except the token embeddings.
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.eval()

    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


textual_inversion_strategy = TrainingStrategy(
    callbacks=textual_inversion_strategy_callbacks,
    prepare=textual_inversion_prepare,
)
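
# Usage sketch (illustrative, not part of the strategy itself): wiring the
# strategy into a run might look roughly like the following. `model_path`,
# `output_dir`, the pre-built optimizer, dataloaders, and LR scheduler, and
# loading MultiCLIPTokenizer via from_pretrained are assumptions about the
# surrounding repo; the training loop that consumes the returned callbacks is
# expected to live in training.functional.
#
#   accelerator = Accelerator(mixed_precision="fp16")
#   tokenizer = MultiCLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
#   text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
#   vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
#   unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
#   sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
#       model_path, subfolder="scheduler")
#
#   (text_encoder, unet, optimizer, train_dataloader,
#    val_dataloader, lr_scheduler, _) = textual_inversion_strategy.prepare(
#       accelerator, text_encoder, unet, optimizer,
#       train_dataloader, val_dataloader, lr_scheduler)
#
#   callbacks = textual_inversion_strategy.callbacks(
#       accelerator=accelerator, unet=unet, text_encoder=text_encoder,
#       tokenizer=tokenizer, vae=vae, sample_scheduler=sample_scheduler,
#       train_dataloader=train_dataloader, val_dataloader=val_dataloader,
#       sample_output_dir=output_dir / "samples",
#       checkpoint_output_dir=output_dir / "checkpoints",
#       seed=seed, placeholder_tokens=placeholder_tokens,
#       placeholder_token_ids=placeholder_token_ids)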