summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-01 12:34:42 +0100
committerVolpeon <git@volpeon.ink>2023-03-01 12:34:42 +0100
commita1b8327085ddeab589be074d7e9df4291aba1210 (patch)
tree2f2016916d7a2f659268c3e375d55c59583c2b3b /training/strategy/dreambooth.py
parentFixed TI normalization order (diff)
downloadtextual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz
textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2
textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip
Update
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0290327..e5e84c8 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks(
88 88
89 def on_prepare(): 89 def on_prepare():
90 unet.requires_grad_(True) 90 unet.requires_grad_(True)
91 text_encoder.requires_grad_(True) 91 text_encoder.text_model.encoder.requires_grad_(True)
92 text_encoder.text_model.embeddings.requires_grad_(False) 92 text_encoder.text_model.final_layer_norm.requires_grad_(True)
93 93
94 if ema_unet is not None: 94 if ema_unet is not None:
95 ema_unet.to(accelerator.device) 95 ema_unet.to(accelerator.device)
@@ -203,7 +203,7 @@ def dreambooth_prepare(
203 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 203 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
204 **kwargs 204 **kwargs
205): 205):
206 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) 206 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
207 207
208 208
209dreambooth_strategy = TrainingStrategy( 209dreambooth_strategy = TrainingStrategy(