summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index 7aa4960..451b61b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -159,7 +159,7 @@ def parse_args():
159 parser.add_argument( 159 parser.add_argument(
160 "--tag_dropout", 160 "--tag_dropout",
161 type=float, 161 type=float,
162 default=0.1, 162 default=0,
163 help="Tag dropout probability.", 163 help="Tag dropout probability.",
164 ) 164 )
165 parser.add_argument( 165 parser.add_argument(
@@ -407,7 +407,7 @@ def parse_args():
407 ) 407 )
408 parser.add_argument( 408 parser.add_argument(
409 "--emb_decay", 409 "--emb_decay",
410 default=1e-2, 410 default=10,
411 type=float, 411 type=float,
412 help="Embedding decay factor." 412 help="Embedding decay factor."
413 ) 413 )
@@ -597,7 +597,7 @@ def main():
597 597
598 def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): 598 def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
599 if len(placeholder_tokens) == 1: 599 if len(placeholder_tokens) == 1:
600 sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}") 600 sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
601 else: 601 else:
602 sample_output_dir = output_dir.joinpath("samples") 602 sample_output_dir = output_dir.joinpath("samples")
603 603