diff options
author | Volpeon <git@volpeon.ink> | 2023-03-23 11:07:57 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-23 11:07:57 +0100 |
commit | 0767c7bc82645186159965c2a6be4278e33c6721 (patch) | |
tree | a136470ab85dbb99ab51d9be4a7831fe21612ab3 /train_dreambooth.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.gz textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.bz2 textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.zip |
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index f8f6e84..a85ae4c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -582,12 +582,15 @@ def main(): | |||
582 | ) | 582 | ) |
583 | datamodule.setup() | 583 | datamodule.setup() |
584 | 584 | ||
585 | optimizer = create_optimizer( | 585 | params_to_optimize = (unet.parameters(), ) |
586 | itertools.chain( | 586 | if args.train_text_encoder_epochs != 0: |
587 | unet.parameters(), | 587 | params_to_optimize += ( |
588 | text_encoder.text_model.encoder.parameters(), | 588 | text_encoder.text_model.encoder.parameters(), |
589 | text_encoder.text_model.final_layer_norm.parameters(), | 589 | text_encoder.text_model.final_layer_norm.parameters(), |
590 | ), | 590 | ) |
591 | |||
592 | optimizer = create_optimizer( | ||
593 | itertools.chain(*params_to_optimize), | ||
591 | lr=args.learning_rate, | 594 | lr=args.learning_rate, |
592 | ) | 595 | ) |
593 | 596 | ||