diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/dreambooth.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 35cccbb..dc19ba3 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -31,6 +31,7 @@ def dreambooth_strategy_callbacks( | |||
31 | checkpoint_output_dir: Path, | 31 | checkpoint_output_dir: Path, |
32 | seed: int, | 32 | seed: int, |
33 | train_text_encoder_cycles: int, | 33 | train_text_encoder_cycles: int, |
34 | text_encoder_unfreeze_last_n_layers: int = 2, | ||
34 | max_grad_norm: float = 1.0, | 35 | max_grad_norm: float = 1.0, |
35 | use_ema: bool = False, | 36 | use_ema: bool = False, |
36 | ema_inv_gamma: float = 1.0, | 37 | ema_inv_gamma: float = 1.0, |
@@ -211,7 +212,7 @@ def dreambooth_prepare( | |||
211 | ]: | 212 | ]: |
212 | layer.requires_grad_(False) | 213 | layer.requires_grad_(False) |
213 | 214 | ||
214 | text_encoder.text_model.embeddings.requires_grad_(False) | 215 | # text_encoder.text_model.embeddings.requires_grad_(False) |
215 | 216 | ||
216 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 217 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
217 | 218 | ||