path: root/train_lora.py
author    Volpeon <git@volpeon.ink>  2023-04-02 08:42:33 +0200
committer Volpeon <git@volpeon.ink>  2023-04-02 08:42:33 +0200
commit    11e6f8f88483e6cfdccd66ad758ae1dfcfc0283b (patch)
tree      e66c62abb974c01769285b1c01c748e6c49cc97b /train_lora.py
parent    Revert (diff)
Lora: Only register params with grad to optimizer
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 8fc2d69..cf73645 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -662,9 +662,13 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        itertools.chain(
-            unet.parameters(),
-            text_encoder.parameters(),
+        (
+            param
+            for param in itertools.chain(
+                unet.parameters(),
+                text_encoder.parameters(),
+            )
+            if param.requires_grad
         ),
         lr=args.learning_rate,
     )
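
The change replaces a plain itertools.chain over all unet and text_encoder parameters with a generator expression that yields only parameters with requires_grad set, so frozen weights are never registered with the optimizer (which would otherwise allocate optimizer state, e.g. Adam moments, for parameters that never receive a gradient). A minimal sketch of the same pattern; torch.optim.AdamW and the two toy nn.Linear modules are stand-ins for the repository's create_optimizer helper and the real models:

    import itertools

    import torch
    import torch.nn as nn

    # Toy modules standing in for unet and text_encoder.
    unet = nn.Linear(4, 4)
    text_encoder = nn.Linear(4, 4)

    # Freeze everything, then unfreeze only the parts being trained
    # (with LoRA, that would be the injected low-rank adapter weights).
    for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
        param.requires_grad = False
    unet.weight.requires_grad = True

    # Register only the parameters that will actually receive gradients.
    trainable = (
        param
        for param in itertools.chain(unet.parameters(), text_encoder.parameters())
        if param.requires_grad
    )
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)

Passing a generator works because torch optimizers materialize the iterable internally; the filter keeps the optimizer's state dict limited to the weights actually being trained.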