diff options
-rw-r--r-- | train_dreambooth.py | 2 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 3 |
2 files changed, 3 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 7745d27..d284346 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -262,7 +262,7 @@ def parse_args(): | |||
262 | ) | 262 | ) |
263 | parser.add_argument( | 263 | parser.add_argument( |
264 | "--text_encoder_unfreeze_last_n_layers", | 264 | "--text_encoder_unfreeze_last_n_layers", |
265 | default=2, | 265 | default=-1, |
266 | help="Number of text encoder layers to train.", | 266 | help="Number of text encoder layers to train.", |
267 | ) | 267 | ) |
268 | parser.add_argument( | 268 | parser.add_argument( |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 35cccbb..dc19ba3 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -31,6 +31,7 @@ def dreambooth_strategy_callbacks( | |||
31 | checkpoint_output_dir: Path, | 31 | checkpoint_output_dir: Path, |
32 | seed: int, | 32 | seed: int, |
33 | train_text_encoder_cycles: int, | 33 | train_text_encoder_cycles: int, |
34 | text_encoder_unfreeze_last_n_layers: int = 2, | ||
34 | max_grad_norm: float = 1.0, | 35 | max_grad_norm: float = 1.0, |
35 | use_ema: bool = False, | 36 | use_ema: bool = False, |
36 | ema_inv_gamma: float = 1.0, | 37 | ema_inv_gamma: float = 1.0, |
@@ -211,7 +212,7 @@ def dreambooth_prepare( | |||
211 | ]: | 212 | ]: |
212 | layer.requires_grad_(False) | 213 | layer.requires_grad_(False) |
213 | 214 | ||
214 | text_encoder.text_model.embeddings.requires_grad_(False) | 215 | # text_encoder.text_model.embeddings.requires_grad_(False) |
215 | 216 | ||
216 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 217 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
217 | 218 | ||