author    Volpeon <git@volpeon.ink>  2023-04-08 17:38:49 +0200
committer Volpeon <git@volpeon.ink>  2023-04-08 17:38:49 +0200
commit    9f5f70cb2a8919cb07821f264bf0fd75bfa10584 (patch)
tree      19bd8802b6cfd941797beabfc0bb2595ffb00b5f /train_ti.py
parent    Fix TI (diff)
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 23 ++++++++++++++++++-----
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 48858cc..daf8bc5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+from typing import Union
 import math
 
 import torch
@@ -75,6 +76,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
+    parser.add_argument(
         "--initializer_noise",
         type=float,
         default=0,
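
Note on the new flag: with nargs='*', argparse yields None when --filter_tokens is absent, an empty list for a bare flag, and a list of strings otherwise; the normalisation added in the next hunk relies on exactly this. A minimal sketch of that standard argparse behaviour (the parser here is a throwaway, not the one in train_ti.py):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--filter_tokens", type=str, nargs='*',
                        help="Tokens to filter the dataset by.")

    print(parser.parse_args([]).filter_tokens)                   # None (flag omitted)
    print(parser.parse_args(["--filter_tokens"]).filter_tokens)  # []
    print(parser.parse_args(
        ["--filter_tokens", "<tok1>", "<tok2>"]).filter_tokens)  # ['<tok1>', '<tok2>']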
@@ -538,6 +545,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if args.sequential:
         args.alias_tokens += [
             item
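
Why the None fallback copies the list: assigning args.placeholder_tokens directly would alias the two lists, so an in-place mutation of one would silently change the other. A small illustration of the aliasing pitfall (variable names hypothetical):

    placeholder = ["<tok>"]
    filtered = placeholder          # alias, not a copy
    filtered += ["<extra>"]         # in-place extend mutates both names
    print(placeholder)              # ['<tok>', '<extra>']

    filtered = placeholder.copy()   # independent list, as in the patch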
@@ -779,13 +792,11 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+    def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
             sample_output_dir = output_dir / "samples"
-            metrics_output_file = output_dir / "lr.png"
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
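
The annotated run() signature is what the new typing import in the first hunk supports. A stub mirroring it, under the assumption that the project targets Python 3.9, where builtin generics like list[str] already parse (PEP 585) but the `int | list[int]` union syntax (PEP 604, Python 3.10) does not:

    from typing import Union

    def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str],
            num_vectors: Union[int, list[int]], data_template: str) -> None:
        ...  # stub only; the real body lives in main()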
@@ -800,6 +811,8 @@ def main():
 
         print(f"{i + 1}: {stats}")
 
+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
+
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
             batch_size=args.train_batch_size,
@@ -820,7 +833,7 @@ def main():
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
             seed=args.seed,
-            filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections),
+            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype
         )
         datamodule.setup()
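
How the filter is wired: functools.partial pre-binds the token list and collection arguments, leaving a one-argument callable for VlpnDataModule's filter parameter. A sketch with a hypothetical stand-in for keyword_filter (only the argument order is taken from the call above; the real implementation lives elsewhere in this repo):

    from functools import partial

    def keyword_filter(tokens, collections, exclude_collections, item):
        # illustrative body: keep items whose prompt mentions a kept token
        return any(token in item["prompt"] for token in tokens)

    item_filter = partial(keyword_filter, ["<tok1>"], None, None)
    print(item_filter({"prompt": "a photo of <tok1>"}))  # True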
@@ -834,7 +847,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
         optimizer = create_optimizer(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
             lr=args.learning_rate,
         )
 
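
The last hunk drops the intermediate `.params` attribute and hands the embedding module's parameters straight to the optimizer factory. A minimal PyTorch sketch of the same pattern (the Embedding and AdamW below are stand-ins, not the repo's token_override_embedding or create_optimizer):

    import torch

    embedding = torch.nn.Embedding(8, 768)  # stand-in trainable embedding module
    optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-4)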