"""Dreambooth training strategy: training-loop callbacks and model preparation.

``dreambooth_strategy`` bundles :func:`dreambooth_strategy_callbacks` and
:func:`dreambooth_prepare` into a ``TrainingStrategy``. The UNet is always
trained; the text encoder is trained only while
``cycle < train_text_encoder_cycles``.
"""

import inspect
import itertools
from contextlib import contextmanager, nullcontext
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Optional

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def _bind_forward(model: torch.nn.Module) -> None:
    """Ensure ``model.forward`` is a method bound to ``model``.

    ``MethodType`` is only correct for a plain function: wrapping an
    already-bound method would pass ``model`` as an extra positional argument
    on every call. Which of the two ``unwrap_model(..., keep_fp32_wrapper=False)``
    leaves behind presumably depends on the Accelerate version
    (NOTE(review): confirm against the pinned version), so only plain
    callables are re-bound.
    """
    if not inspect.ismethod(model.forward):
        model.forward = MethodType(model.forward, model)


def dreambooth_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    placeholder_tokens: list[str],
    placeholder_token_ids: list[list[int]],
    train_text_encoder_cycles: int,
    text_encoder_unfreeze_last_n_layers: int = 2,
    use_emb_decay: bool = False,
    emb_decay_target: float = 0.4,
    emb_decay: float = 1e-2,
    max_grad_norm: float = 1.0,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: int = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    """Build the Dreambooth :class:`TrainingCallbacks`.

    Args:
        accelerator: Used for the device, the mixed-precision sampling dtype,
            gradient clipping and unwrapping the models.
        unet: UNet being trained; possibly Accelerate-wrapped (it is
            unwrapped wherever the raw module is needed).
        text_encoder: Text encoder; trained only while
            ``cycle < train_text_encoder_cycles``. Possibly Accelerate-wrapped.
        tokenizer: Switched between train/eval mode and used for sampling and
            the saved pipeline.
        vae: Passed to sampling and to the saved pipeline.
        sample_scheduler: Scheduler used for sampling and the saved pipeline.
        train_dataloader: Forwarded to ``save_samples``.
        val_dataloader: Forwarded to ``save_samples``.
        sample_output_dir: Directory for samples; created if missing.
        checkpoint_output_dir: Directory for the final pipeline; created if
            missing.
        seed: Forwarded to ``save_samples``.
        placeholder_tokens: Embedding decay only runs when this is non-empty.
        placeholder_token_ids: Not used by these callbacks.
        train_text_encoder_cycles: Number of initial cycles during which the
            text encoder is put in train mode and its gradients are clipped.
        text_encoder_unfreeze_last_n_layers: Not used here; consumed by
            :func:`dreambooth_prepare`.
        use_emb_decay: Enable norm decay of the text-encoder embedding
            parameters that received gradients.
        emb_decay_target: Norm each embedding row is pulled towards.
        emb_decay: Decay strength; the per-step factor is
            ``emb_decay * lrs["emb"]``.
        max_grad_norm: Gradient-clipping threshold.
        use_ema: Keep an exponential moving average of the UNet parameters.
        ema_inv_gamma: Forwarded to ``EMAModel``.
        ema_power: Forwarded to ``EMAModel``.
        ema_max_decay: Forwarded to ``EMAModel`` as ``max_value``.
        sample_batch_size: Forwarded to ``save_samples`` as ``batch_size``.
        sample_num_batches: Forwarded to ``save_samples`` as ``num_batches``.
        sample_num_steps: Forwarded to ``save_samples`` as ``num_steps``.
        sample_guidance_scale: Forwarded to ``save_samples`` as
            ``guidance_scale``.
        sample_image_size: Forwarded to ``save_samples`` as ``image_size``.

    Returns:
        A ``TrainingCallbacks`` wired with the train/eval/optimize/log/
        checkpoint/sample hooks defined below.
    """
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # Dtype the models are cast to while sampling.
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    if use_ema:
        ema_unet = EMAModel(
            unet.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
        ema_unet.to(accelerator.device)
    else:
        ema_unet = None

    def ema_context():
        """Context applying the EMA UNet weights via ``apply_temporary``.

        Returns a no-op ``nullcontext`` when EMA is disabled.
        """
        if ema_unet is not None:
            return ema_unet.apply_temporary(unet.parameters())
        return nullcontext()

    @contextmanager
    def on_train(cycle: int):
        """Put the UNet and tokenizer in train mode.

        The text encoder is switched to train mode only during the first
        ``train_text_encoder_cycles`` cycles; otherwise its mode is left as is.
        """
        unet.train()
        tokenizer.train()

        if cycle < train_text_encoder_cycles:
            text_encoder.train()

        yield

    @contextmanager
    def on_eval():
        """Put all models in eval mode and apply EMA weights (if any) inside."""
        unet.eval()
        tokenizer.eval()
        text_encoder.eval()

        with ema_context():
            yield

    def on_before_optimize(cycle: int):
        """Clip gradients and collect embedding parameters for decay.

        Returns:
            When embedding decay is enabled and there are placeholder tokens,
            a list of the text-encoder embedding ``Parameter`` objects that
            received gradients (or ``None`` if there are none); otherwise
            ``None``. This value is presumably passed by the training loop to
            ``on_after_optimize`` as ``w``.
        """
        params_to_clip = [unet.parameters()]
        if cycle < train_text_encoder_cycles:
            params_to_clip.append(text_encoder.parameters())
        accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

        if len(placeholder_tokens) != 0 and use_emb_decay:
            # Unwrap so ``.text_model`` is reachable even if the text encoder
            # is wrapped (e.g. by DDP).
            embeddings = accelerator.unwrap_model(
                text_encoder, keep_fp32_wrapper=True
            ).text_model.embeddings
            # Hand back the Parameter objects themselves rather than a stacked
            # tensor: ``torch.stack`` allocates a copy, so in-place decay
            # applied to it would never reach the model.
            params = [p for p in embeddings.parameters() if p.grad is not None]
            return params if params else None

        return None

    @torch.no_grad()
    def on_after_optimize(w, lrs: dict[str, float]):
        """Update the UNet EMA and apply embedding norm decay.

        Args:
            w: Result of ``on_before_optimize``: an iterable of embedding
                parameters to decay in place, or ``None``.
            lrs: Learning rates by parameter-group name; decay uses
                ``lrs["emb"]`` and is skipped when that key is missing.
        """
        if ema_unet is not None:
            ema_unet.step(unet.parameters())

        if w is not None and "emb" in lrs:
            lambda_ = emb_decay * lrs["emb"]

            if lambda_ != 0:
                for p in w:
                    # Move each row along its own direction so that its norm
                    # approaches ``emb_decay_target`` by ``lambda_`` of the gap.
                    norm = p.norm(dim=-1, keepdim=True)
                    p.add_(
                        (p / norm.clamp_min(1e-12))
                        * lambda_
                        * (emb_decay_target - norm)
                    )

    def on_log():
        """Extra log values: the current EMA decay when EMA is enabled."""
        if ema_unet is not None:
            return {"ema_decay": ema_unet.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        """Save the full pipeline to ``checkpoint_output_dir``.

        Only acts when ``postfix == "end"``; every other checkpoint is
        ignored. The saved UNet uses the EMA weights when EMA is enabled.
        """
        if postfix != "end":
            return

        print("Saving model...")

        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

        _bind_forward(text_encoder_)
        _bind_forward(unet_)

        # Project-specific embeddings hook. NOTE(review): presumably folds the
        # trained embeddings into the saved weights — confirm in models.clip.
        text_encoder_.text_model.embeddings.persist(False)

        with ema_context():
            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder_,
                vae=vae,
                unet=unet_,
                tokenizer=tokenizer,
                scheduler=sample_scheduler,
            )
            pipeline.save_pretrained(checkpoint_output_dir)

        del unet_, text_encoder_, pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def on_sample(cycle, step):
        """Generate samples with the models temporarily cast to ``weight_dtype``.

        The original dtypes are restored even if sampling raises.
        """
        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)

        orig_unet_dtype = unet_.dtype
        orig_text_encoder_dtype = text_encoder_.dtype

        unet_.to(dtype=weight_dtype)
        text_encoder_.to(dtype=weight_dtype)

        try:
            save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_)
        finally:
            unet_.to(dtype=orig_unet_dtype)
            text_encoder_.to(dtype=orig_text_encoder_dtype)

        del unet_, text_encoder_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_after_optimize=on_after_optimize,
        on_log=on_log,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def dreambooth_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    text_encoder_unfreeze_last_n_layers: int = 2,
    **kwargs
):
    """Freeze parts of the text encoder, then run ``accelerator.prepare``.

    Args:
        text_encoder_unfreeze_last_n_layers: ``0`` freezes the whole text
            encoder layer stack; ``n > 0`` freezes all but the last ``n``
            layers; a negative value leaves the layers' ``requires_grad``
            untouched.
        **kwargs: Ignored (presumably the shared strategy options).

    The position embedding is always frozen.

    Returns:
        ``(text_encoder, unet, optimizer, train_dataloader, val_dataloader,
        lr_scheduler)`` as returned by ``accelerator.prepare``.
    """
    # Freeze before ``prepare``: a DDP-wrapped model no longer exposes
    # ``.text_model``, and DDP records which parameters require gradients
    # when it is constructed.
    encoder = text_encoder.text_model.encoder
    if text_encoder_unfreeze_last_n_layers == 0:
        encoder.requires_grad_(False)
    elif text_encoder_unfreeze_last_n_layers > 0:
        for layer in encoder.layers[:-text_encoder_unfreeze_last_n_layers]:
            layer.requires_grad_(False)

    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    (
        text_encoder,
        unet,
        optimizer,
        train_dataloader,
        val_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler


dreambooth_strategy = TrainingStrategy(
    callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare
)