diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 89f4113..1d0cb6f 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -691,7 +691,7 @@ def main(): | |||
691 | placeholder_tokens=alias_placeholder_tokens, | 691 | placeholder_tokens=alias_placeholder_tokens, |
692 | initializer_tokens=alias_initializer_tokens, | 692 | initializer_tokens=alias_initializer_tokens, |
693 | ) | 693 | ) |
694 | embeddings.persist() | 694 | embeddings.persist(True) |
695 | print( | 695 | print( |
696 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | 696 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" |
697 | ) | 697 | ) |
@@ -712,7 +712,7 @@ def main(): | |||
712 | args.placeholder_tokens = added_tokens | 712 | args.placeholder_tokens = added_tokens |
713 | print("Training embeddings from embeddings dir") | 713 | print("Training embeddings from embeddings dir") |
714 | else: | 714 | else: |
715 | embeddings.persist() | 715 | embeddings.persist(True) |
716 | 716 | ||
717 | if args.scale_lr: | 717 | if args.scale_lr: |
718 | args.learning_rate = ( | 718 | args.learning_rate = ( |
@@ -1067,7 +1067,7 @@ def main(): | |||
1067 | args.train_data_template, | 1067 | args.train_data_template, |
1068 | ): | 1068 | ): |
1069 | run(i, [placeholder_token], [initializer_token], num_vectors, data_template) | 1069 | run(i, [placeholder_token], [initializer_token], num_vectors, data_template) |
1070 | embeddings.persist() | 1070 | embeddings.persist(True) |
1071 | 1071 | ||
1072 | 1072 | ||
1073 | if __name__ == "__main__": | 1073 | if __name__ == "__main__": |