diff options
author | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
commit | 21d70916f66e74a87c631a06b70774954b085b48 (patch) | |
tree | d1b443b9270f45ae6936f3acb565f767c7c65b1f /training/strategy/dreambooth.py | |
parent | Run PTI only if placeholder tokens arg isn't empty (diff) | |
download | textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.tar.gz textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.tar.bz2 textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.zip |
Fix
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r-- | training/strategy/dreambooth.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 9808027..0286673 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -84,11 +84,9 @@ def dreambooth_strategy_callbacks( | |||
84 | else: | 84 | else: |
85 | return nullcontext() | 85 | return nullcontext() |
86 | 86 | ||
87 | def on_accum_model(): | ||
88 | return unet | ||
89 | |||
90 | @contextmanager | 87 | @contextmanager |
91 | def on_train(epoch: int): | 88 | def on_train(epoch: int): |
89 | unet.train() | ||
92 | tokenizer.train() | 90 | tokenizer.train() |
93 | 91 | ||
94 | if epoch < train_text_encoder_epochs: | 92 | if epoch < train_text_encoder_epochs: |
@@ -101,6 +99,7 @@ def dreambooth_strategy_callbacks( | |||
101 | 99 | ||
102 | @contextmanager | 100 | @contextmanager |
103 | def on_eval(): | 101 | def on_eval(): |
102 | unet.eval() | ||
104 | tokenizer.eval() | 103 | tokenizer.eval() |
105 | text_encoder.eval() | 104 | text_encoder.eval() |
106 | 105 | ||
@@ -174,7 +173,6 @@ def dreambooth_strategy_callbacks( | |||
174 | torch.cuda.empty_cache() | 173 | torch.cuda.empty_cache() |
175 | 174 | ||
176 | return TrainingCallbacks( | 175 | return TrainingCallbacks( |
177 | on_accum_model=on_accum_model, | ||
178 | on_train=on_train, | 176 | on_train=on_train, |
179 | on_eval=on_eval, | 177 | on_eval=on_eval, |
180 | on_before_optimize=on_before_optimize, | 178 | on_before_optimize=on_before_optimize, |