summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-07 11:31:21 +0200
committerVolpeon <git@volpeon.ink>2023-04-07 11:31:21 +0200
commit37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f (patch)
tree1f18d01cc23418789b6b4b00b38edc0a80b6214a /train_ti.py
parentFix (diff)
downloadtextual-inversion-diff-37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f.tar.gz
textual-inversion-diff-37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f.tar.bz2
textual-inversion-diff-37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f.zip
Run PTI only if placeholder tokens arg isn't empty
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py
index 344b412..c1c0eed 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -602,7 +602,7 @@ def main():
602 elif args.mixed_precision == "bf16": 602 elif args.mixed_precision == "bf16":
603 weight_dtype = torch.bfloat16 603 weight_dtype = torch.bfloat16
604 604
605 logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG) 605 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
606 606
607 if args.seed is None: 607 if args.seed is None:
608 args.seed = torch.random.seed() >> 32 608 args.seed = torch.random.seed() >> 32
@@ -743,7 +743,7 @@ def main():
743 else: 743 else:
744 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 744 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
745 745
746 checkpoint_output_dir = output_dir/"checkpoints" 746 checkpoint_output_dir = output_dir / "checkpoints"
747 747
748 trainer = partial( 748 trainer = partial(
749 train, 749 train,
@@ -782,11 +782,11 @@ def main():
782 782
783 def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): 783 def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
784 if len(placeholder_tokens) == 1: 784 if len(placeholder_tokens) == 1:
785 sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}" 785 sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
786 metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png" 786 metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
787 else: 787 else:
788 sample_output_dir = output_dir/"samples" 788 sample_output_dir = output_dir / "samples"
789 metrics_output_file = output_dir/f"lr.png" 789 metrics_output_file = output_dir / "lr.png"
790 790
791 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 791 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
792 tokenizer=tokenizer, 792 tokenizer=tokenizer,