summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--  dreambooth.py  12
1 file changed, 6 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py
index e71b7f0..17107d0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -653,10 +653,10 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW

-    if args.initializer_token is not None:
-        text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()
-    else:
+    if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
+    else:
+        text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()

     # Initialize the optimizer
     optimizer = optimizer_class(
@@ -945,9 +945,9 @@ def main():

                 if accelerator.sync_gradients:
                     params_to_clip = (
-                        unet.parameters()
-                        if args.initializer_token is not None
-                        else itertools.chain(unet.parameters(), text_encoder.parameters())
+                        itertools.chain(unet.parameters(), text_encoder.parameters())
+                        if args.train_text_encoder
+                        else unet.parameters()
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
