summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index 89f4113..1d0cb6f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -691,7 +691,7 @@ def main():
691 placeholder_tokens=alias_placeholder_tokens, 691 placeholder_tokens=alias_placeholder_tokens,
692 initializer_tokens=alias_initializer_tokens, 692 initializer_tokens=alias_initializer_tokens,
693 ) 693 )
694 embeddings.persist() 694 embeddings.persist(True)
695 print( 695 print(
696 f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" 696 f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
697 ) 697 )
@@ -712,7 +712,7 @@ def main():
712 args.placeholder_tokens = added_tokens 712 args.placeholder_tokens = added_tokens
713 print("Training embeddings from embeddings dir") 713 print("Training embeddings from embeddings dir")
714 else: 714 else:
715 embeddings.persist() 715 embeddings.persist(True)
716 716
717 if args.scale_lr: 717 if args.scale_lr:
718 args.learning_rate = ( 718 args.learning_rate = (
@@ -1067,7 +1067,7 @@ def main():
1067 args.train_data_template, 1067 args.train_data_template,
1068 ): 1068 ):
1069 run(i, [placeholder_token], [initializer_token], num_vectors, data_template) 1069 run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
1070 embeddings.persist() 1070 embeddings.persist(True)
1071 1071
1072 1072
1073if __name__ == "__main__": 1073if __name__ == "__main__":