diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py index 8fc2d69..cf73645 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -662,9 +662,13 @@ def main(): | |||
662 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 662 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
663 | 663 | ||
664 | optimizer = create_optimizer( | 664 | optimizer = create_optimizer( |
665 | itertools.chain( | 665 | ( |
666 | unet.parameters(), | 666 | param |
667 | text_encoder.parameters(), | 667 | for param in itertools.chain( |
668 | unet.parameters(), | ||
669 | text_encoder.parameters(), | ||
670 | ) | ||
671 | if param.requires_grad | ||
668 | ), | 672 | ), |
669 | lr=args.learning_rate, | 673 | lr=args.learning_rate, |
670 | ) | 674 | ) |