From 2dfd1790078753f19ca8c585ac77079f3114f3a9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 21:47:06 +0100 Subject: Training update --- train_ti.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index e696577..e7aeb23 100644 --- a/train_ti.py +++ b/train_ti.py @@ -572,6 +572,8 @@ def main(): callbacks_fn=textual_inversion_strategy ) + checkpoint_output_dir = output_dir.joinpath("checkpoints") + for i, placeholder_token, initializer_token, num_vectors, data_template in zip( range(len(args.placeholder_tokens)), args.placeholder_tokens, @@ -579,8 +581,7 @@ def main(): args.num_vectors, args.train_data_template ): - cur_subdir = output_dir.joinpath(placeholder_token) - cur_subdir.mkdir(parents=True, exist_ok=True) + sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}") placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, @@ -655,7 +656,8 @@ def main(): # -- tokenizer=tokenizer, sample_scheduler=sample_scheduler, - output_dir=cur_subdir, + sample_output_dir=sample_output_dir, + checkpoint_output_dir=checkpoint_output_dir, placeholder_tokens=[placeholder_token], placeholder_token_ids=placeholder_token_ids, learning_rate=args.learning_rate, -- cgit v1.2.3-54-g00ecf