From 9336eb51528a689297453ca8a0414e463fca4184 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Oct 2022 18:14:33 +0200 Subject: Fix --- dreambooth.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index e71b7f0..17107d0 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -653,10 +653,10 @@ def main(): else: optimizer_class = torch.optim.AdamW - if args.initializer_token is not None: - text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() - else: + if args.train_text_encoder: text_encoder_params_to_optimize = text_encoder.parameters() + else: + text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() # Initialize the optimizer optimizer = optimizer_class( @@ -945,9 +945,9 @@ def main(): if accelerator.sync_gradients: params_to_clip = ( - unet.parameters() - if args.initializer_token is not None - else itertools.chain(unet.parameters(), text_encoder.parameters()) + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) -- cgit v1.2.3-54-g00ecf