summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-03 22:25:20 +0200
committerVolpeon <git@volpeon.ink>2023-04-03 22:25:20 +0200
commit2e654c017780d37f3304436e2feb84b619f1c023 (patch)
tree8a248fe17c3512110de9fcfed7f7bfd708b3b8da /train_ti.py
parentTI: Delta learning (diff)
downloadtextual-inversion-diff-2e654c017780d37f3304436e2feb84b619f1c023.tar.gz
textual-inversion-diff-2e654c017780d37f3304436e2feb84b619f1c023.tar.bz2
textual-inversion-diff-2e654c017780d37f3304436e2feb84b619f1c023.zip
Improved sparse embeddings
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index 0ad7574..a9a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -809,7 +809,7 @@ def main():
809 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 809 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
810 810
811 optimizer = create_optimizer( 811 optimizer = create_optimizer(
812 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 812 text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
813 lr=args.learning_rate, 813 lr=args.learning_rate,
814 ) 814 )
815 815