diff options
author | Volpeon <git@volpeon.ink> | 2023-01-16 21:47:06 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-16 21:47:06 +0100 |
commit | 2dfd1790078753f19ca8c585ac77079f3114f3a9 (patch) | |
tree | d1d1d643f247767c13535105dbe4afafcc5ab8c0 /train_ti.py | |
parent | If valid set size is 0, re-use one image from train set (diff) | |
download | textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.tar.gz textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.tar.bz2 textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.zip |
Training update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index e696577..e7aeb23 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -572,6 +572,8 @@ def main(): | |||
572 | callbacks_fn=textual_inversion_strategy | 572 | callbacks_fn=textual_inversion_strategy |
573 | ) | 573 | ) |
574 | 574 | ||
575 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | ||
576 | |||
575 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 577 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
576 | range(len(args.placeholder_tokens)), | 578 | range(len(args.placeholder_tokens)), |
577 | args.placeholder_tokens, | 579 | args.placeholder_tokens, |
@@ -579,8 +581,7 @@ def main(): | |||
579 | args.num_vectors, | 581 | args.num_vectors, |
580 | args.train_data_template | 582 | args.train_data_template |
581 | ): | 583 | ): |
582 | cur_subdir = output_dir.joinpath(placeholder_token) | 584 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}") |
583 | cur_subdir.mkdir(parents=True, exist_ok=True) | ||
584 | 585 | ||
585 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 586 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
586 | tokenizer=tokenizer, | 587 | tokenizer=tokenizer, |
@@ -655,7 +656,8 @@ def main(): | |||
655 | # -- | 656 | # -- |
656 | tokenizer=tokenizer, | 657 | tokenizer=tokenizer, |
657 | sample_scheduler=sample_scheduler, | 658 | sample_scheduler=sample_scheduler, |
658 | output_dir=cur_subdir, | 659 | sample_output_dir=sample_output_dir, |
660 | checkpoint_output_dir=checkpoint_output_dir, | ||
659 | placeholder_tokens=[placeholder_token], | 661 | placeholder_tokens=[placeholder_token], |
660 | placeholder_token_ids=placeholder_token_ids, | 662 | placeholder_token_ids=placeholder_token_ids, |
661 | learning_rate=args.learning_rate, | 663 | learning_rate=args.learning_rate, |