summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py
index 56f9e97..2840def 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -513,6 +513,12 @@ def main():
513 mixed_precision=args.mixed_precision 513 mixed_precision=args.mixed_precision
514 ) 514 )
515 515
516 weight_dtype = torch.float32
517 if args.mixed_precision == "fp16":
518 weight_dtype = torch.float16
519 elif args.mixed_precision == "bf16":
520 weight_dtype = torch.bfloat16
521
516 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) 522 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
517 523
518 if args.seed is None: 524 if args.seed is None:
@@ -564,12 +570,6 @@ def main():
564 else: 570 else:
565 optimizer_class = torch.optim.AdamW 571 optimizer_class = torch.optim.AdamW
566 572
567 weight_dtype = torch.float32
568 if args.mixed_precision == "fp16":
569 weight_dtype = torch.float16
570 elif args.mixed_precision == "bf16":
571 weight_dtype = torch.bfloat16
572
573 checkpoint_output_dir = output_dir.joinpath("checkpoints") 573 checkpoint_output_dir = output_dir.joinpath("checkpoints")
574 574
575 trainer = partial( 575 trainer = partial(