diff options
author | Volpeon <git@volpeon.ink> | 2022-10-27 18:14:33 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-27 18:14:33 +0200 |
commit | 9336eb51528a689297453ca8a0414e463fca4184 (patch) | |
tree | 801a3e91289f83eb16749580be6a764dedb98b11 | |
parent | Added CLI arg to set dataloader worker num; improved text encoder handling wi... (diff) | |
download | textual-inversion-diff-9336eb51528a689297453ca8a0414e463fca4184.tar.gz textual-inversion-diff-9336eb51528a689297453ca8a0414e463fca4184.tar.bz2 textual-inversion-diff-9336eb51528a689297453ca8a0414e463fca4184.zip |
Fix
-rw-r--r-- | dreambooth.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py index e71b7f0..17107d0 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -653,10 +653,10 @@ def main(): | |||
653 | else: | 653 | else: |
654 | optimizer_class = torch.optim.AdamW | 654 | optimizer_class = torch.optim.AdamW |
655 | 655 | ||
656 | if args.initializer_token is not None: | 656 | if args.train_text_encoder: |
657 | text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() | ||
658 | else: | ||
659 | text_encoder_params_to_optimize = text_encoder.parameters() | 657 | text_encoder_params_to_optimize = text_encoder.parameters() |
658 | else: | ||
659 | text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() | ||
660 | 660 | ||
661 | # Initialize the optimizer | 661 | # Initialize the optimizer |
662 | optimizer = optimizer_class( | 662 | optimizer = optimizer_class( |
@@ -945,9 +945,9 @@ def main(): | |||
945 | 945 | ||
946 | if accelerator.sync_gradients: | 946 | if accelerator.sync_gradients: |
947 | params_to_clip = ( | 947 | params_to_clip = ( |
948 | unet.parameters() | 948 | itertools.chain(unet.parameters(), text_encoder.parameters()) |
949 | if args.initializer_token is not None | 949 | if args.train_text_encoder |
950 | else itertools.chain(unet.parameters(), text_encoder.parameters()) | 950 | else unet.parameters() |
951 | ) | 951 | ) |
952 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) | 952 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) |
953 | 953 | ||