summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py
index 8fc2d69..cf73645 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -662,9 +662,13 @@ def main():
662 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 662 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
663 663
664 optimizer = create_optimizer( 664 optimizer = create_optimizer(
665 itertools.chain( 665 (
666 unet.parameters(), 666 param
667 text_encoder.parameters(), 667 for param in itertools.chain(
668 unet.parameters(),
669 text_encoder.parameters(),
670 )
671 if param.requires_grad
668 ), 672 ),
669 lr=args.learning_rate, 673 lr=args.learning_rate,
670 ) 674 )