From 0767c7bc82645186159965c2a6be4278e33c6721 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 23 Mar 2023 11:07:57 +0100 Subject: Update --- train_dreambooth.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index f8f6e84..a85ae4c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -582,12 +582,15 @@ def main(): ) datamodule.setup() - optimizer = create_optimizer( - itertools.chain( - unet.parameters(), + params_to_optimize = (unet.parameters(), ) + if args.train_text_encoder_epochs != 0: + params_to_optimize += ( text_encoder.text_model.encoder.parameters(), text_encoder.text_model.final_layer_norm.parameters(), - ), + ) + + optimizer = create_optimizer( + itertools.chain(*params_to_optimize), lr=args.learning_rate, ) -- cgit v1.2.3-54-g00ecf