summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/train_ti.py b/train_ti.py
index 4366c9e..6757bde 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -204,7 +204,8 @@ def parse_args():
204 "--vector_shuffle", 204 "--vector_shuffle",
205 type=str, 205 type=str,
206 default="auto", 206 default="auto",
207 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', 207 choices=["all", "trailing", "leading", "between", "auto", "off"],
208 help='Vector shuffling algorithm.',
208 ) 209 )
209 parser.add_argument( 210 parser.add_argument(
210 "--offset_noise_strength", 211 "--offset_noise_strength",
@@ -253,10 +254,9 @@ def parse_args():
253 "--lr_scheduler", 254 "--lr_scheduler",
254 type=str, 255 type=str,
255 default="one_cycle", 256 default="one_cycle",
256 help=( 257 choices=["linear", "cosine", "cosine_with_restarts", "polynomial",
257 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 258 "constant", "constant_with_warmup", "one_cycle"],
258 ' "constant", "constant_with_warmup", "one_cycle"]' 259 help='The scheduler type to use.',
259 ),
260 ) 260 )
261 parser.add_argument( 261 parser.add_argument(
262 "--lr_warmup_epochs", 262 "--lr_warmup_epochs",
@@ -280,7 +280,7 @@ def parse_args():
280 "--lr_warmup_func", 280 "--lr_warmup_func",
281 type=str, 281 type=str,
282 default="cos", 282 default="cos",
283 help='Choose between ["linear", "cos"]' 283 choices=["linear", "cos"],
284 ) 284 )
285 parser.add_argument( 285 parser.add_argument(
286 "--lr_warmup_exp", 286 "--lr_warmup_exp",
@@ -292,7 +292,7 @@ def parse_args():
292 "--lr_annealing_func", 292 "--lr_annealing_func",
293 type=str, 293 type=str,
294 default="cos", 294 default="cos",
295 help='Choose between ["linear", "half_cos", "cos"]' 295 choices=["linear", "half_cos", "cos"],
296 ) 296 )
297 parser.add_argument( 297 parser.add_argument(
298 "--lr_annealing_exp", 298 "--lr_annealing_exp",
@@ -330,7 +330,8 @@ def parse_args():
330 "--optimizer", 330 "--optimizer",
331 type=str, 331 type=str,
332 default="dadan", 332 default="dadan",
333 help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]' 333 choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"],
334 help='Optimizer to use'
334 ) 335 )
335 parser.add_argument( 336 parser.add_argument(
336 "--dadaptation_d0", 337 "--dadaptation_d0",