author     Volpeon <git@volpeon.ink>  2023-01-16 21:47:06 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-16 21:47:06 +0100
commit     2dfd1790078753f19ca8c585ac77079f3114f3a9
tree       d1d1d643f247767c13535105dbe4afafcc5ab8c0
parent     If valid set size is 0, re-use one image from train set
Training update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  |  8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index e696577..e7aeb23 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -572,6 +572,8 @@ def main():
         callbacks_fn=textual_inversion_strategy
     )
 
+    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+
     for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
         range(len(args.placeholder_tokens)),
         args.placeholder_tokens,
@@ -579,8 +581,7 @@ def main():
         args.num_vectors,
         args.train_data_template
     ):
-        cur_subdir = output_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
+        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -655,7 +656,8 @@ def main():
             # --
             tokenizer=tokenizer,
             sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
+            sample_output_dir=sample_output_dir,
+            checkpoint_output_dir=checkpoint_output_dir,
             placeholder_tokens=[placeholder_token],
             placeholder_token_ids=placeholder_token_ids,
             learning_rate=args.learning_rate,
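
In effect, the commit replaces the single per-token output subdirectory (cur_subdir) with one shared checkpoints directory plus a per-token samples directory. A minimal sketch of the resulting layout logic, with output_dir and the token list as hypothetical stand-ins for the values built earlier in main() (the loop body is illustrative, not the full trainer call):

from pathlib import Path

# Hypothetical stand-ins for values set up earlier in main().
output_dir = Path("output/run")
placeholder_tokens = ["<token1>", "<token2>"]  # stand-in for args.placeholder_tokens

# After this commit: one shared directory for checkpoints...
checkpoint_output_dir = output_dir.joinpath("checkpoints")

# ...and one samples directory per placeholder token.
for placeholder_token in placeholder_tokens:
    sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")
    # Before this commit, everything for a token went into
    # output_dir / placeholder_token (the old cur_subdir).
    print(sample_output_dir, checkpoint_output_dir)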