summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 93c81cb..bc26ee6 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
15from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 15from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
16from models.clip.tokenizer import MultiCLIPTokenizer 16from models.clip.tokenizer import MultiCLIPTokenizer
17from training.util import EMAModel 17from training.util import EMAModel
18from training.functional import TrainingCallbacks, save_samples 18from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
19 19
20 20
21def dreambooth_strategy( 21def dreambooth_strategy_callbacks(
22 accelerator: Accelerator, 22 accelerator: Accelerator,
23 unet: UNet2DConditionModel, 23 unet: UNet2DConditionModel,
24 text_encoder: CLIPTextModel, 24 text_encoder: CLIPTextModel,
@@ -185,3 +185,9 @@ def dreambooth_strategy(
185 on_checkpoint=on_checkpoint, 185 on_checkpoint=on_checkpoint,
186 on_sample=on_sample, 186 on_sample=on_sample,
187 ) 187 )
188
189
190dreambooth_strategy = TrainingStrategy(
191 callbacks=dreambooth_strategy_callbacks,
192 prepare_unet=True
193)