From 37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 7 Apr 2023 11:31:21 +0200 Subject: Run PTI only if placeholder tokens arg isn't empty --- train_ti.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 344b412..c1c0eed 100644 --- a/train_ti.py +++ b/train_ti.py @@ -602,7 +602,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG) + logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: args.seed = torch.random.seed() >> 32 @@ -743,7 +743,7 @@ def main(): else: raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") - checkpoint_output_dir = output_dir/"checkpoints" + checkpoint_output_dir = output_dir / "checkpoints" trainer = partial( train, @@ -782,11 +782,11 @@ def main(): def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): if len(placeholder_tokens) == 1: - sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}" - metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png" + sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}" + metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png" else: - sample_output_dir = output_dir/"samples" - metrics_output_file = output_dir/f"lr.png" + sample_output_dir = output_dir / "samples" + metrics_output_file = output_dir / "lr.png" placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, -- cgit v1.2.3-54-g00ecf