diff options
| -rw-r--r-- | dreambooth.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py index e71b7f0..17107d0 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -653,10 +653,10 @@ def main(): | |||
| 653 | else: | 653 | else: |
| 654 | optimizer_class = torch.optim.AdamW | 654 | optimizer_class = torch.optim.AdamW |
| 655 | 655 | ||
| 656 | if args.initializer_token is not None: | 656 | if args.train_text_encoder: |
| 657 | text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() | ||
| 658 | else: | ||
| 659 | text_encoder_params_to_optimize = text_encoder.parameters() | 657 | text_encoder_params_to_optimize = text_encoder.parameters() |
| 658 | else: | ||
| 659 | text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() | ||
| 660 | 660 | ||
| 661 | # Initialize the optimizer | 661 | # Initialize the optimizer |
| 662 | optimizer = optimizer_class( | 662 | optimizer = optimizer_class( |
| @@ -945,9 +945,9 @@ def main(): | |||
| 945 | 945 | ||
| 946 | if accelerator.sync_gradients: | 946 | if accelerator.sync_gradients: |
| 947 | params_to_clip = ( | 947 | params_to_clip = ( |
| 948 | unet.parameters() | 948 | itertools.chain(unet.parameters(), text_encoder.parameters()) |
| 949 | if args.initializer_token is not None | 949 | if args.train_text_encoder |
| 950 | else itertools.chain(unet.parameters(), text_encoder.parameters()) | 950 | else unet.parameters() |
| 951 | ) | 951 | ) |
| 952 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) | 952 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) |
| 953 | 953 | ||
