From 5821523a524190490a287c5e2aacb6e72cc3a4cf Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 17 Jan 2023 07:20:45 +0100 Subject: Update --- training/strategy/dreambooth.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'training/strategy/dreambooth.py') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 93c81cb..bc26ee6 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.tokenizer import MultiCLIPTokenizer from training.util import EMAModel -from training.functional import TrainingCallbacks, save_samples +from training.functional import TrainingStrategy, TrainingCallbacks, save_samples -def dreambooth_strategy( +def dreambooth_strategy_callbacks( accelerator: Accelerator, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, @@ -185,3 +185,9 @@ def dreambooth_strategy( on_checkpoint=on_checkpoint, on_sample=on_sample, ) + + +dreambooth_strategy = TrainingStrategy( + callbacks=dreambooth_strategy_callbacks, + prepare_unet=True +) -- cgit v1.2.3-54-g00ecf