diff options
author | Volpeon <git@volpeon.ink> | 2023-04-02 08:42:33 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-02 08:42:33 +0200 |
commit | 11e6f8f88483e6cfdccd66ad758ae1dfcfc0283b (patch) | |
tree | e66c62abb974c01769285b1c01c748e6c49cc97b /train_lora.py | |
parent | Revert (diff) | |
download | textual-inversion-diff-11e6f8f88483e6cfdccd66ad758ae1dfcfc0283b.tar.gz textual-inversion-diff-11e6f8f88483e6cfdccd66ad758ae1dfcfc0283b.tar.bz2 textual-inversion-diff-11e6f8f88483e6cfdccd66ad758ae1dfcfc0283b.zip |
Lora: Only register params with grad to optimizer
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 10 |
1 file changed, 7 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py index 8fc2d69..cf73645 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -662,9 +662,13 @@ def main(): | |||
662 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 662 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
663 | 663 | ||
664 | optimizer = create_optimizer( | 664 | optimizer = create_optimizer( |
665 | itertools.chain( | 665 | ( |
666 | unet.parameters(), | 666 | param |
667 | text_encoder.parameters(), | 667 | for param in itertools.chain( |
668 | unet.parameters(), | ||
669 | text_encoder.parameters(), | ||
670 | ) | ||
671 | if param.requires_grad | ||
668 | ), | 672 | ), |
669 | lr=args.learning_rate, | 673 | lr=args.learning_rate, |
670 | ) | 674 | ) |