summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 9808027..0286673 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -84,11 +84,9 @@ def dreambooth_strategy_callbacks(
84 else: 84 else:
85 return nullcontext() 85 return nullcontext()
86 86
87 def on_accum_model():
88 return unet
89
90 @contextmanager 87 @contextmanager
91 def on_train(epoch: int): 88 def on_train(epoch: int):
89 unet.train()
92 tokenizer.train() 90 tokenizer.train()
93 91
94 if epoch < train_text_encoder_epochs: 92 if epoch < train_text_encoder_epochs:
@@ -101,6 +99,7 @@ def dreambooth_strategy_callbacks(
101 99
102 @contextmanager 100 @contextmanager
103 def on_eval(): 101 def on_eval():
102 unet.eval()
104 tokenizer.eval() 103 tokenizer.eval()
105 text_encoder.eval() 104 text_encoder.eval()
106 105
@@ -174,7 +173,6 @@ def dreambooth_strategy_callbacks(
174 torch.cuda.empty_cache() 173 torch.cuda.empty_cache()
175 174
176 return TrainingCallbacks( 175 return TrainingCallbacks(
177 on_accum_model=on_accum_model,
178 on_train=on_train, 176 on_train=on_train,
179 on_eval=on_eval, 177 on_eval=on_eval,
180 on_before_optimize=on_before_optimize, 178 on_before_optimize=on_before_optimize,