From 3f922880475c2c0a5679987d4a9a43606e838566 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 22:26:43 +0100
Subject: Added Dreambooth strategy

---
 train_ti.py                     |  46 +++++-----
 training/strategy/dreambooth.py | 183 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 206 insertions(+), 23 deletions(-)
 create mode 100644 training/strategy/dreambooth.py

diff --git a/train_ti.py b/train_ti.py
index 77dec12..2497519 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -557,15 +557,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -624,6 +615,29 @@
             args.sample_steps
         )
 
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -642,20 +656,6 @@
             warmup_epochs=args.lr_warmup_epochs,
         )
 
-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
-        vae=vae,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
-        callbacks_fn=textual_inversion_strategy
-    )
-
     trainer(
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
new file mode 100644
index 0000000..6e7ebe2
--- /dev/null
+++ b/training/strategy/dreambooth.py
@@ -0,0 +1,183 @@
+from contextlib import nullcontext
+from typing import Optional
+from functools import partial
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+import itertools
+
+import torch
+from torch.utils.data import DataLoader
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
+from training.functional import TrainingCallbacks, save_samples
+
+
+def dreambooth_strategy(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    output_dir: Path,
+    seed: int,
+    train_text_encoder_epochs: int,
+    max_grad_norm: float = 1.0,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
+    sample_batch_size: int = 1,
+    sample_num_batches: int = 1,
+    sample_num_steps: int = 20,
+    sample_guidance_scale: float = 7.5,
+    sample_image_size: Optional[int] = None,
+):
+    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
+
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    save_samples_ = partial(
+        save_samples,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        vae=vae,
+        sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        output_dir=output_dir,
+        seed=seed,
+        batch_size=sample_batch_size,
+        num_batches=sample_num_batches,
+        num_steps=sample_num_steps,
+        guidance_scale=sample_guidance_scale,
+        image_size=sample_image_size,
+    )
+
+    if use_ema:
+        ema_unet = EMAModel(
+            unet.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+    else:
+        ema_unet = None
+
+    def ema_context():
+        if use_ema:
+            return ema_unet.apply_temporary(unet.parameters())
+        else:
+            return nullcontext()
+
+    def on_model():
+        return unet
+
+    def on_prepare():
+        unet.requires_grad_(True)
+        text_encoder.requires_grad_(True)
+        text_encoder.text_model.embeddings.persist()
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
+
+        if use_ema:
+            ema_unet.to(accelerator.device)
+
+    @contextmanager
+    def on_train(epoch: int):
+        tokenizer.train()
+
+        if epoch < train_text_encoder_epochs:
+            text_encoder.train()
+        elif epoch == train_text_encoder_epochs:
+            text_encoder.requires_grad_(False)
+            text_encoder.eval()
+
+        yield
+
+    @contextmanager
+    def on_eval():
+        tokenizer.eval()
+        text_encoder.eval()
+
+        with ema_context():
+            yield
+
+    def on_before_optimize(epoch: int):
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if epoch < train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if use_ema:
+            ema_unet.step(unet.parameters())
+
+    def on_log():
+        if use_ema:
+            return {"ema_decay": ema_unet.decay}
+        return {}
+
+    @torch.no_grad()
+    def on_checkpoint(step, postfix):
+        if postfix != "end":
+            return
+
+        print("Saving model...")
+
+        unet_ = accelerator.unwrap_model(unet)
+        text_encoder_ = accelerator.unwrap_model(text_encoder)
+
+        with ema_context():
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder_,
+                vae=vae,
+                unet=unet_,
+                tokenizer=tokenizer,
+                scheduler=sample_scheduler,
+            )
+            pipeline.save_pretrained(output_dir.joinpath("model"))
+
+        del unet_
+        del text_encoder_
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @torch.no_grad()
+    def on_sample(step):
+        with ema_context():
+            save_samples_(step=step)
+
+    return TrainingCallbacks(
+        on_prepare=on_prepare,
+        on_model=on_model,
+        on_train=on_train,
+        on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_log=on_log,
+        on_checkpoint=on_checkpoint,
+        on_sample=on_sample,
+    )
-- 
cgit v1.2.3-70-g09d2