diff options
Diffstat (limited to 'training/strategy/dreambooth.py')
| -rw-r--r-- | training/strategy/dreambooth.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0286673..695174a 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks( | |||
| 106 | with ema_context(): | 106 | with ema_context(): |
| 107 | yield | 107 | yield |
| 108 | 108 | ||
| 109 | def on_before_optimize(lr: float, epoch: int): | 109 | def on_before_optimize(epoch: int): |
| 110 | params_to_clip = [unet.parameters()] | 110 | params_to_clip = [unet.parameters()] |
| 111 | if epoch < train_text_encoder_epochs: | 111 | if epoch < train_text_encoder_epochs: |
| 112 | params_to_clip.append(text_encoder.parameters()) | 112 | params_to_clip.append(text_encoder.parameters()) |
