diff options
Diffstat (limited to 'training/strategy/dreambooth.py')
| -rw-r--r-- | training/strategy/dreambooth.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index d813b49..f57e736 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -99,8 +99,7 @@ def dreambooth_strategy_callbacks( | |||
| 99 | def on_prepare(): | 99 | def on_prepare(): |
| 100 | unet.requires_grad_(True) | 100 | unet.requires_grad_(True) |
| 101 | text_encoder.requires_grad_(True) | 101 | text_encoder.requires_grad_(True) |
| 102 | text_encoder.text_model.embeddings.persist() | 102 | text_encoder.text_model.embeddings.requires_grad_(False) |
| 103 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) | ||
| 104 | 103 | ||
| 105 | if ema_unet is not None: | 104 | if ema_unet is not None: |
| 106 | ema_unet.to(accelerator.device) | 105 | ema_unet.to(accelerator.device) |
| @@ -125,7 +124,7 @@ def dreambooth_strategy_callbacks( | |||
| 125 | with ema_context(): | 124 | with ema_context(): |
| 126 | yield | 125 | yield |
| 127 | 126 | ||
| 128 | def on_before_optimize(epoch: int): | 127 | def on_before_optimize(lr: float, epoch: int): |
| 129 | if accelerator.sync_gradients: | 128 | if accelerator.sync_gradients: |
| 130 | params_to_clip = [unet.parameters()] | 129 | params_to_clip = [unet.parameters()] |
| 131 | if epoch < train_text_encoder_epochs: | 130 | if epoch < train_text_encoder_epochs: |
