diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 344b412..c1c0eed 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -602,7 +602,7 @@ def main(): | |||
602 | elif args.mixed_precision == "bf16": | 602 | elif args.mixed_precision == "bf16": |
603 | weight_dtype = torch.bfloat16 | 603 | weight_dtype = torch.bfloat16 |
604 | 604 | ||
605 | logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG) | 605 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
606 | 606 | ||
607 | if args.seed is None: | 607 | if args.seed is None: |
608 | args.seed = torch.random.seed() >> 32 | 608 | args.seed = torch.random.seed() >> 32 |
@@ -743,7 +743,7 @@ def main(): | |||
743 | else: | 743 | else: |
744 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 744 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
745 | 745 | ||
746 | checkpoint_output_dir = output_dir/"checkpoints" | 746 | checkpoint_output_dir = output_dir / "checkpoints" |
747 | 747 | ||
748 | trainer = partial( | 748 | trainer = partial( |
749 | train, | 749 | train, |
@@ -782,11 +782,11 @@ def main(): | |||
782 | 782 | ||
783 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): | 783 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): |
784 | if len(placeholder_tokens) == 1: | 784 | if len(placeholder_tokens) == 1: |
785 | sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}" | 785 | sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}" |
786 | metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png" | 786 | metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png" |
787 | else: | 787 | else: |
788 | sample_output_dir = output_dir/"samples" | 788 | sample_output_dir = output_dir / "samples" |
789 | metrics_output_file = output_dir/f"lr.png" | 789 | metrics_output_file = output_dir / "lr.png" |
790 | 790 | ||
791 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 791 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
792 | tokenizer=tokenizer, | 792 | tokenizer=tokenizer, |