from contextlib import contextmanager, nullcontext
from typing import Optional
from functools import partial
from pathlib import Path
import itertools

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def dreambooth_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    train_text_encoder_epochs: int,
    max_grad_norm: float = 1.0,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: int = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    """
    Build the TrainingCallbacks for DreamBooth-style fine-tuning: the UNet is trained
    throughout, the text encoder only for the first `train_text_encoder_epochs` epochs,
    and an optional EMA copy of the UNet weights is used for checkpoints and samples.
    """
    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # The sampling dtype follows the accelerator's mixed-precision setting.
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=weight_dtype,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    if use_ema:
        ema_unet = EMAModel(
            unet.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
    else:
        ema_unet = None

    def ema_context():
        # Temporarily swap the EMA weights into the UNet; a no-op when EMA is disabled.
        if use_ema:
            return ema_unet.apply_temporary(unet.parameters())
        else:
            return nullcontext()

    def on_model():
        return unet

    def on_prepare():
        unet.requires_grad_(True)
        text_encoder.requires_grad_(True)
        # Persist learned token embeddings and freeze the temporary embedding table.
        text_encoder.text_model.embeddings.persist()
        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)

        if use_ema:
            ema_unet.to(accelerator.device)

    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()

        # Train the text encoder only for the first `train_text_encoder_epochs` epochs.
        if epoch < train_text_encoder_epochs:
            text_encoder.train()
        elif epoch == train_text_encoder_epochs:
            text_encoder.requires_grad_(False)
            text_encoder.eval()

        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()
        text_encoder.eval()

        with ema_context():
            yield

    def on_before_optimize(epoch: int):
        if accelerator.sync_gradients:
            # Clip the gradients of every parameter group that is currently being trained.
            params_to_clip = [unet.parameters()]
            if epoch < train_text_encoder_epochs:
                params_to_clip.append(text_encoder.parameters())
            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

    @torch.no_grad()
    def on_after_optimize(lr: float):
        if use_ema:
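            # Update the EMA shadow copy of the UNet parameters after each optimizer step.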
            ema_unet.step(unet.parameters())

    def on_log():
        if use_ema:
            return {"ema_decay": ema_unet.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        # Only a final checkpoint is written; intermediate checkpoints are skipped.
        if postfix != "end":
            return

        print("Saving model...")

        unet_ = accelerator.unwrap_model(unet)
        text_encoder_ = accelerator.unwrap_model(text_encoder)

        with ema_context():
            # Export the full pipeline, using the EMA weights if they are enabled.
            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder_,
                vae=vae,
                unet=unet_,
                tokenizer=tokenizer,
                scheduler=sample_scheduler,
            )
            pipeline.save_pretrained(checkpoint_output_dir)

        del unet_
        del text_encoder_
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def on_sample(step):
        with ema_context():
            save_samples_(step=step)

    return TrainingCallbacks(
        on_prepare=on_prepare,
        on_model=on_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_after_optimize=on_after_optimize,
        on_log=on_log,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


dreambooth_strategy = TrainingStrategy(
    callbacks=dreambooth_strategy_callbacks,
    prepare_unet=True,
)
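# `dreambooth_strategy` bundles the callback factory above for the generic training loop
# in training.functional; `prepare_unet=True` presumably tells that loop to pass the UNet
# through `accelerator.prepare` (an assumption based on the flag name). A hypothetical
# call site, with the entry-point name assumed rather than defined in this module:
#
#     from training.functional import train  # assumed name
#
#     train(..., strategy=dreambooth_strategy)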