summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 35cccbb..dc19ba3 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -31,6 +31,7 @@ def dreambooth_strategy_callbacks(
31 checkpoint_output_dir: Path, 31 checkpoint_output_dir: Path,
32 seed: int, 32 seed: int,
33 train_text_encoder_cycles: int, 33 train_text_encoder_cycles: int,
34 text_encoder_unfreeze_last_n_layers: int = 2,
34 max_grad_norm: float = 1.0, 35 max_grad_norm: float = 1.0,
35 use_ema: bool = False, 36 use_ema: bool = False,
36 ema_inv_gamma: float = 1.0, 37 ema_inv_gamma: float = 1.0,
@@ -211,7 +212,7 @@ def dreambooth_prepare(
211 ]: 212 ]:
212 layer.requires_grad_(False) 213 layer.requires_grad_(False)
213 214
214 text_encoder.text_model.embeddings.requires_grad_(False) 215 # text_encoder.text_model.embeddings.requires_grad_(False)
215 216
216 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 217 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
217 218