from contextlib import contextmanager, nullcontext
from typing import Optional
from functools import partial
from pathlib import Path
import itertools

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def dreambooth_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    train_text_encoder_epochs: int,
    max_grad_norm: float = 1.0,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: int = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # Match the sampling dtype to the accelerator's mixed-precision setting.
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=weight_dtype,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    # Optionally keep an exponential moving average of the UNet weights.
    if use_ema:
        ema_unet = EMAModel(
            unet.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
    else:
        ema_unet = None

    def ema_context():
        # Temporarily swap the EMA weights into the UNet; no-op when EMA is disabled.
        if ema_unet is not None:
            return ema_unet.apply_temporary(unet.parameters())
        else:
            return nullcontext()

    def on_model():
        return unet

    def on_prepare():
        unet.requires_grad_(True)
        text_encoder.requires_grad_(True)
        text_encoder.text_model.embeddings.requires_grad_(False)

        if ema_unet is not None:
            ema_unet.to(accelerator.device)

    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()

        # Train the text encoder only for the first `train_text_encoder_epochs` epochs,
        # then freeze it for the rest of training.
        if epoch < train_text_encoder_epochs:
            text_encoder.train()
        elif epoch == train_text_encoder_epochs:
            text_encoder.requires_grad_(False)
            text_encoder.eval()

        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()
        text_encoder.eval()

        with ema_context():
            yield

    def on_before_optimize(lr: float, epoch: int):
        if accelerator.sync_gradients:
            # Clip gradients of the UNet and, while it is still being trained, the text encoder.
            params_to_clip = [unet.parameters()]
            if epoch < train_text_encoder_epochs:
                params_to_clip.append(text_encoder.parameters())
            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

    @torch.no_grad()
    def on_after_optimize(lr: float):
        if ema_unet is not None:
            ema_unet.step(unet.parameters())

    def on_log():
        if ema_unet is not None:
            return {"ema_decay": ema_unet.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        # Only the final checkpoint is saved as a full pipeline.
        if postfix != "end":
            return

        print("Saving model...")

        unet_ = accelerator.unwrap_model(unet)
        text_encoder_ = accelerator.unwrap_model(text_encoder)

        with ema_context():
            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder_,
                vae=vae,
                unet=unet_,
                tokenizer=tokenizer,
                scheduler=sample_scheduler,
            )
            pipeline.save_pretrained(checkpoint_output_dir)

        del unet_
        del text_encoder_
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def on_sample(step):
        with ema_context():
            save_samples_(step=step)

    return TrainingCallbacks(
        on_prepare=on_prepare,
        on_model=on_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_after_optimize=on_after_optimize,
        on_log=on_log,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def dreambooth_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    *args,
):
    prep = [text_encoder, unet] + list(args)
    (
        text_encoder,
        unet,
        optimizer,
        train_dataloader,
        val_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(*prep)
    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler


dreambooth_strategy = TrainingStrategy(
    callbacks=dreambooth_strategy_callbacks,
    prepare=dreambooth_prepare,
)
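

# --- Hedged usage sketch (illustration only, not part of the strategy) ---------
# The commented snippet below shows how a training loop *might* drive the callbacks
# produced by this strategy. The accelerator, models, dataloaders, optimizer, and
# the surrounding loop are assumed to be created elsewhere; in this repo the real
# driver lives in `training.functional` and consumes `dreambooth_strategy` directly,
# so treat this only as a sketch of the callback protocol implemented above.
#
#   callbacks = dreambooth_strategy.callbacks(
#       accelerator=accelerator,
#       unet=unet,
#       text_encoder=text_encoder,
#       tokenizer=tokenizer,
#       vae=vae,
#       sample_scheduler=sample_scheduler,
#       train_dataloader=train_dataloader,
#       val_dataloader=val_dataloader,
#       sample_output_dir=Path("output/samples"),        # hypothetical paths
#       checkpoint_output_dir=Path("output/checkpoints"),
#       seed=42,
#       train_text_encoder_epochs=10,
#       use_ema=True,
#   )
#
#   callbacks.on_prepare()
#   for epoch in range(num_epochs):
#       with callbacks.on_train(epoch):
#           for batch in train_dataloader:
#               ...  # forward/backward pass handled by the trainer
#               callbacks.on_before_optimize(lr, epoch)
#               optimizer.step()
#               callbacks.on_after_optimize(lr)
#   callbacks.on_checkpoint(global_step, "end")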