diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-16 21:47:06 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-16 21:47:06 +0100 |
| commit | 2dfd1790078753f19ca8c585ac77079f3114f3a9 (patch) | |
| tree | d1d1d643f247767c13535105dbe4afafcc5ab8c0 /train_ti.py | |
| parent | If valid set size is 0, re-use one image from train set (diff) | |
| download | textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.tar.gz textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.tar.bz2 textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.zip | |
Training update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index e696577..e7aeb23 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -572,6 +572,8 @@ def main(): | |||
| 572 | callbacks_fn=textual_inversion_strategy | 572 | callbacks_fn=textual_inversion_strategy |
| 573 | ) | 573 | ) |
| 574 | 574 | ||
| 575 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | ||
| 576 | |||
| 575 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 577 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
| 576 | range(len(args.placeholder_tokens)), | 578 | range(len(args.placeholder_tokens)), |
| 577 | args.placeholder_tokens, | 579 | args.placeholder_tokens, |
| @@ -579,8 +581,7 @@ def main(): | |||
| 579 | args.num_vectors, | 581 | args.num_vectors, |
| 580 | args.train_data_template | 582 | args.train_data_template |
| 581 | ): | 583 | ): |
| 582 | cur_subdir = output_dir.joinpath(placeholder_token) | 584 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}") |
| 583 | cur_subdir.mkdir(parents=True, exist_ok=True) | ||
| 584 | 585 | ||
| 585 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 586 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 586 | tokenizer=tokenizer, | 587 | tokenizer=tokenizer, |
| @@ -655,7 +656,8 @@ def main(): | |||
| 655 | # -- | 656 | # -- |
| 656 | tokenizer=tokenizer, | 657 | tokenizer=tokenizer, |
| 657 | sample_scheduler=sample_scheduler, | 658 | sample_scheduler=sample_scheduler, |
| 658 | output_dir=cur_subdir, | 659 | sample_output_dir=sample_output_dir, |
| 660 | checkpoint_output_dir=checkpoint_output_dir, | ||
| 659 | placeholder_tokens=[placeholder_token], | 661 | placeholder_tokens=[placeholder_token], |
| 660 | placeholder_token_ids=placeholder_token_ids, | 662 | placeholder_token_ids=placeholder_token_ids, |
| 661 | learning_rate=args.learning_rate, | 663 | learning_rate=args.learning_rate, |
