from typing import Optional
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path
import itertools

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def dreambooth_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    train_text_encoder_epochs: int,
    max_grad_norm: float = 1.0,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: float = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    """Build the TrainingCallbacks used by the Dreambooth training strategy.

    Creates the sample and checkpoint output directories, chooses the dtype
    used for sampling from the accelerator's mixed-precision setting, and
    optionally keeps an EMA copy of the UNet parameters.

    Args:
        accelerator: Accelerator whose mixed-precision mode, device and
            ``unwrap_model``/``clip_grad_norm_`` are used.
        unet: UNet being fine-tuned.
        text_encoder: Text encoder; trained only for the first
            ``train_text_encoder_epochs`` epochs, then frozen.
        tokenizer: Tokenizer, switched between train/eval together with the
            models.
        vae: VAE passed to sampling and to the saved pipeline.
        sample_scheduler: Scheduler used for sampling and saved with the
            pipeline.
        train_dataloader: Forwarded to ``save_samples``.
        val_dataloader: Forwarded to ``save_samples``; may be ``None``.
        sample_output_dir: Directory for sample images (created if missing).
        checkpoint_output_dir: Directory the final pipeline is saved to
            (created if missing).
        seed: Forwarded to ``save_samples``.
        train_text_encoder_epochs: Number of leading epochs during which the
            text encoder is trained and its gradients clipped.
        max_grad_norm: Max norm passed to ``accelerator.clip_grad_norm_``.
        use_ema: Whether to maintain an EMA of the UNet parameters.
        ema_inv_gamma: Forwarded to ``EMAModel(inv_gamma=...)``.
        ema_power: Forwarded to ``EMAModel(power=...)``.
        ema_max_decay: Forwarded to ``EMAModel(max_value=...)``.
        sample_batch_size: Forwarded to ``save_samples(batch_size=...)``.
        sample_num_batches: Forwarded to ``save_samples(num_batches=...)``.
        sample_num_steps: Forwarded to ``save_samples(num_steps=...)``.
        sample_guidance_scale: Forwarded to ``save_samples(guidance_scale=...)``.
        sample_image_size: Forwarded to ``save_samples(image_size=...)``.

    Returns:
        A ``TrainingCallbacks`` wiring the callbacks defined below.
    """
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # dtype the models are temporarily cast to while generating samples.
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Pre-bind every sampling argument that does not change between calls.
    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    if use_ema:
        ema_unet = EMAModel(
            unet.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
        ema_unet.to(accelerator.device)
    else:
        ema_unet = None

    def ema_context():
        """Return a context that temporarily applies the EMA UNet weights, or a no-op without EMA."""
        if ema_unet is not None:
            return ema_unet.apply_temporary(unet.parameters())
        else:
            return nullcontext()

    @contextmanager
    def on_train(epoch: int):
        """Put the models in training mode for ``epoch``.

        The text encoder is trained while ``epoch < train_text_encoder_epochs``.
        For every later epoch it is frozen and put in eval mode.
        """
        unet.train()
        tokenizer.train()

        if epoch < train_text_encoder_epochs:
            text_encoder.train()
        else:
            # Freeze on every epoch past the threshold (not only the first one)
            # so a run that resumes after the threshold still ends up frozen.
            # Both calls are idempotent.
            text_encoder.requires_grad_(False)
            text_encoder.eval()

        yield

    @contextmanager
    def on_eval():
        """Put all models in eval mode, with EMA UNet weights applied if EMA is enabled."""
        unet.eval()
        tokenizer.eval()
        text_encoder.eval()

        with ema_context():
            yield

    def on_before_optimize(epoch: int):
        """Clip gradients of the UNet, and of the text encoder while it is still trained."""
        params_to_clip = [unet.parameters()]
        if epoch < train_text_encoder_epochs:
            params_to_clip.append(text_encoder.parameters())
        accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

    @torch.no_grad()
    def on_after_optimize(_, lr: float):
        """Advance the EMA after an optimizer step (``lr`` is unused here)."""
        if ema_unet is not None:
            ema_unet.step(unet.parameters())

    def on_log():
        """Return extra values to log: the current EMA decay when EMA is enabled."""
        if ema_unet is not None:
            return {"ema_decay": ema_unet.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        """Save the full pipeline, with EMA weights applied if enabled, once at the end.

        Intermediate checkpoints (any ``postfix`` other than ``"end"``) are skipped.
        """
        if postfix != "end":
            return

        print("Saving model...")

        # NOTE(review): keep_fp32_wrapper=False removes accelerate's
        # mixed-precision forward wrapper. That is acceptable only because this
        # is the final save; confirm whether unwrap_model mutates the model in place.
        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

        with ema_context():
            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder_,
                vae=vae,
                unet=unet_,
                tokenizer=tokenizer,
                scheduler=sample_scheduler,
            )
            pipeline.save_pretrained(checkpoint_output_dir)

        del unet_
        del text_encoder_
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def on_sample(step):
        """Generate samples at ``step`` with the models temporarily cast to ``weight_dtype``."""
        with ema_context():
            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)

            orig_unet_dtype = unet_.dtype
            orig_text_encoder_dtype = text_encoder_.dtype

            # NOTE(review): this casts the live parameters in place and then
            # casts them back. When weight_dtype is lower precision (fp16/bf16),
            # the round trip rounds the weights. The text encoder is always
            # affected. The UNet is presumably restored by the EMA temporary
            # swap when EMA is enabled — confirm, and consider sampling
            # without an in-place cast.
            unet_.to(dtype=weight_dtype)
            text_encoder_.to(dtype=weight_dtype)

            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)

            unet_.to(dtype=orig_unet_dtype)
            text_encoder_.to(dtype=orig_text_encoder_dtype)

            del unet_
            del text_encoder_

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_after_optimize=on_after_optimize,
        on_log=on_log,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def dreambooth_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    **kwargs
):
    """Freeze the text encoder's embeddings, then pass everything through ``accelerator.prepare``.

    Extra keyword arguments are accepted for interface compatibility and ignored.

    Returns:
        ``(text_encoder, unet, optimizer, train_dataloader, val_dataloader,
        lr_scheduler, {})``, where the first six are the prepared objects and
        the empty dict carries no extra state.
    """
    # Freeze before prepare: after prepare the model may be wrapped (e.g. in
    # DistributedDataParallel), which neither exposes `.text_model` nor picks up
    # requires_grad changes made after it was constructed.
    text_encoder.text_model.embeddings.requires_grad_(False)

    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)

    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


# Strategy object combining the Dreambooth callbacks factory and prepare step.
dreambooth_strategy = TrainingStrategy(
    callbacks=dreambooth_strategy_callbacks,
    prepare=dreambooth_prepare
)