diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 56f9e97..2840def 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -513,6 +513,12 @@ def main(): | |||
| 513 | mixed_precision=args.mixed_precision | 513 | mixed_precision=args.mixed_precision |
| 514 | ) | 514 | ) |
| 515 | 515 | ||
| 516 | weight_dtype = torch.float32 | ||
| 517 | if args.mixed_precision == "fp16": | ||
| 518 | weight_dtype = torch.float16 | ||
| 519 | elif args.mixed_precision == "bf16": | ||
| 520 | weight_dtype = torch.bfloat16 | ||
| 521 | |||
| 516 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 522 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
| 517 | 523 | ||
| 518 | if args.seed is None: | 524 | if args.seed is None: |
| @@ -564,12 +570,6 @@ def main(): | |||
| 564 | else: | 570 | else: |
| 565 | optimizer_class = torch.optim.AdamW | 571 | optimizer_class = torch.optim.AdamW |
| 566 | 572 | ||
| 567 | weight_dtype = torch.float32 | ||
| 568 | if args.mixed_precision == "fp16": | ||
| 569 | weight_dtype = torch.float16 | ||
| 570 | elif args.mixed_precision == "bf16": | ||
| 571 | weight_dtype = torch.bfloat16 | ||
| 572 | |||
| 573 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | 573 | checkpoint_output_dir = output_dir.joinpath("checkpoints") |
| 574 | 574 | ||
| 575 | trainer = partial( | 575 | trainer = partial( |
