diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index 0ad7574..a9a2333 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -809,7 +809,7 @@ def main(): | |||
809 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 809 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
810 | 810 | ||
811 | optimizer = create_optimizer( | 811 | optimizer = create_optimizer( |
812 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 812 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
813 | lr=args.learning_rate, | 813 | lr=args.learning_rate, |
814 | ) | 814 | ) |
815 | 815 | ||