summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-23 11:07:57 +0100
committerVolpeon <git@volpeon.ink>2023-03-23 11:07:57 +0100
commit0767c7bc82645186159965c2a6be4278e33c6721 (patch)
treea136470ab85dbb99ab51d9be4a7831fe21612ab3 /train_dreambooth.py
parentFix (diff)
downloadtextual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.gz
textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.bz2
textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.zip
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index f8f6e84..a85ae4c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -582,12 +582,15 @@ def main():
582 ) 582 )
583 datamodule.setup() 583 datamodule.setup()
584 584
585 optimizer = create_optimizer( 585 params_to_optimize = (unet.parameters(), )
586 itertools.chain( 586 if args.train_text_encoder_epochs != 0:
587 unet.parameters(), 587 params_to_optimize += (
588 text_encoder.text_model.encoder.parameters(), 588 text_encoder.text_model.encoder.parameters(),
589 text_encoder.text_model.final_layer_norm.parameters(), 589 text_encoder.text_model.final_layer_norm.parameters(),
590 ), 590 )
591
592 optimizer = create_optimizer(
593 itertools.chain(*params_to_optimize),
591 lr=args.learning_rate, 594 lr=args.learning_rate,
592 ) 595 )
593 596