summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0290327..e5e84c8 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks(
88 88
89 def on_prepare(): 89 def on_prepare():
90 unet.requires_grad_(True) 90 unet.requires_grad_(True)
91 text_encoder.requires_grad_(True) 91 text_encoder.text_model.encoder.requires_grad_(True)
92 text_encoder.text_model.embeddings.requires_grad_(False) 92 text_encoder.text_model.final_layer_norm.requires_grad_(True)
93 93
94 if ema_unet is not None: 94 if ema_unet is not None:
95 ema_unet.to(accelerator.device) 95 ema_unet.to(accelerator.device)
@@ -203,7 +203,7 @@ def dreambooth_prepare(
203 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 203 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
204 **kwargs 204 **kwargs
205): 205):
206 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) 206 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
207 207
208 208
209dreambooth_strategy = TrainingStrategy( 209dreambooth_strategy = TrainingStrategy(