diff options
author | Volpeon <git@volpeon.ink> | 2023-01-17 07:20:45 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-17 07:20:45 +0100 |
commit | 5821523a524190490a287c5e2aacb6e72cc3a4cf (patch) | |
tree | c0eac536c754f078683be6d59893ad23d70baf51 /training/strategy/dreambooth.py | |
parent | Training update (diff) | |
download | textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.gz textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.bz2 textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.zip |
Update
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r-- | training/strategy/dreambooth.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 93c81cb..bc26ee6 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
18 | from training.functional import TrainingCallbacks, save_samples | 18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
19 | 19 | ||
20 | 20 | ||
21 | def dreambooth_strategy( | 21 | def dreambooth_strategy_callbacks( |
22 | accelerator: Accelerator, | 22 | accelerator: Accelerator, |
23 | unet: UNet2DConditionModel, | 23 | unet: UNet2DConditionModel, |
24 | text_encoder: CLIPTextModel, | 24 | text_encoder: CLIPTextModel, |
@@ -185,3 +185,9 @@ def dreambooth_strategy( | |||
185 | on_checkpoint=on_checkpoint, | 185 | on_checkpoint=on_checkpoint, |
186 | on_sample=on_sample, | 186 | on_sample=on_sample, |
187 | ) | 187 | ) |
188 | |||
189 | |||
190 | dreambooth_strategy = TrainingStrategy( | ||
191 | callbacks=dreambooth_strategy_callbacks, | ||
192 | prepare_unet=True | ||
193 | ) | ||