diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-23 11:07:57 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-23 11:07:57 +0100 |
| commit | 0767c7bc82645186159965c2a6be4278e33c6721 (patch) | |
| tree | a136470ab85dbb99ab51d9be4a7831fe21612ab3 /train_dreambooth.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.gz textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.bz2 textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.zip | |
Update
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index f8f6e84..a85ae4c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -582,12 +582,15 @@ def main(): | |||
| 582 | ) | 582 | ) |
| 583 | datamodule.setup() | 583 | datamodule.setup() |
| 584 | 584 | ||
| 585 | optimizer = create_optimizer( | 585 | params_to_optimize = (unet.parameters(), ) |
| 586 | itertools.chain( | 586 | if args.train_text_encoder_epochs != 0: |
| 587 | unet.parameters(), | 587 | params_to_optimize += ( |
| 588 | text_encoder.text_model.encoder.parameters(), | 588 | text_encoder.text_model.encoder.parameters(), |
| 589 | text_encoder.text_model.final_layer_norm.parameters(), | 589 | text_encoder.text_model.final_layer_norm.parameters(), |
| 590 | ), | 590 | ) |
| 591 | |||
| 592 | optimizer = create_optimizer( | ||
| 593 | itertools.chain(*params_to_optimize), | ||
| 591 | lr=args.learning_rate, | 594 | lr=args.learning_rate, |
| 592 | ) | 595 | ) |
| 593 | 596 | ||
