summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-22 11:22:28 +0200
committerVolpeon <git@volpeon.ink>2023-06-22 11:22:28 +0200
commit15080055bf4330a806c409d3ca69ec5b0eab99f2 (patch)
tree8182e8d8dec1f3345b62bfa24c28b4380fb482d4 /training/strategy/dreambooth.py
parentRemove training guidance_scale (diff)
downloadtextual-inversion-diff-15080055bf4330a806c409d3ca69ec5b0eab99f2.tar.gz
textual-inversion-diff-15080055bf4330a806c409d3ca69ec5b0eab99f2.tar.bz2
textual-inversion-diff-15080055bf4330a806c409d3ca69ec5b0eab99f2.zip
Update
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 43fe838..35cccbb 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -203,10 +203,13 @@ def dreambooth_prepare(
203 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 203 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
204 ) 204 )
205 205
206 for layer in text_encoder.text_model.encoder.layers[ 206 if text_encoder_unfreeze_last_n_layers == 0:
207 : (-1 * text_encoder_unfreeze_last_n_layers) 207 text_encoder.text_model.encoder.requires_grad_(False)
208 ]: 208 elif text_encoder_unfreeze_last_n_layers > 0:
209 layer.requires_grad_(False) 209 for layer in text_encoder.text_model.encoder.layers[
210 : (-1 * text_encoder_unfreeze_last_n_layers)
211 ]:
212 layer.requires_grad_(False)
210 213
211 text_encoder.text_model.embeddings.requires_grad_(False) 214 text_encoder.text_model.embeddings.requires_grad_(False)
212 215