diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/strategy/dreambooth.py | 3 | 
1 files changed, 2 insertions, 1 deletions
| diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 35cccbb..dc19ba3 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -31,6 +31,7 @@ def dreambooth_strategy_callbacks( | |||
| 31 | checkpoint_output_dir: Path, | 31 | checkpoint_output_dir: Path, | 
| 32 | seed: int, | 32 | seed: int, | 
| 33 | train_text_encoder_cycles: int, | 33 | train_text_encoder_cycles: int, | 
| 34 | text_encoder_unfreeze_last_n_layers: int = 2, | ||
| 34 | max_grad_norm: float = 1.0, | 35 | max_grad_norm: float = 1.0, | 
| 35 | use_ema: bool = False, | 36 | use_ema: bool = False, | 
| 36 | ema_inv_gamma: float = 1.0, | 37 | ema_inv_gamma: float = 1.0, | 
| @@ -211,7 +212,7 @@ def dreambooth_prepare( | |||
| 211 | ]: | 212 | ]: | 
| 212 | layer.requires_grad_(False) | 213 | layer.requires_grad_(False) | 
| 213 | 214 | ||
| 214 | text_encoder.text_model.embeddings.requires_grad_(False) | 215 | # text_encoder.text_model.embeddings.requires_grad_(False) | 
| 215 | 216 | ||
| 216 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 217 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 
| 217 | 218 | ||
