author     Volpeon <git@volpeon.ink>  2023-06-23 06:48:38 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-23 06:48:38 +0200
commit     950f1f6bcbb1a767170cea590b828d8e3cdae882 (patch)
tree       019a1d3463b363a3f335e16c50cb6890efc3470f /train_ti.py
parent     Update (diff)
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 20 +-------------------
1 file changed, 1 insertion(+), 19 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 1dbd637..8c63493 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -112,12 +112,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -545,18 +539,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
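The two hunks above remove the --inverted_initializer_tokens feature entirely: every placeholder token used to gain an "inv_"-prefixed twin that was initialized from the inverted initializer words. For reference, a minimal, self-contained sketch of the removed expansion logic; the token values are made up for illustration:

# Sketch of the expansion removed above; example values are hypothetical.
placeholder_tokens = ["<concept>"]
initializer_tokens = ["art"]
inverted_initializer_tokens = "photo"

# A single string was broadcast to match the number of placeholders.
if isinstance(inverted_initializer_tokens, str):
    inverted_initializer_tokens = [inverted_initializer_tokens] * len(placeholder_tokens)

# Each placeholder gained an "inv_"-prefixed twin, paired with an inverted initializer.
if isinstance(inverted_initializer_tokens, list) and len(inverted_initializer_tokens) != 0:
    placeholder_tokens += [f"inv_{t}" for t in placeholder_tokens]
    initializer_tokens += inverted_initializer_tokens

print(placeholder_tokens)  # ['<concept>', 'inv_<concept>']
print(initializer_tokens)  # ['art', 'photo']

The surviving context lines apply the same broadcast pattern to --num_vectors: a single integer is repeated once per placeholder token.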
@@ -872,7 +854,7 @@ def main():
 
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=learning_rate,
+        lr=args.learning_rate,
    )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
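The final hunk changes where the optimizer's learning rate comes from: the script now reads args.learning_rate, parsed from the command line, instead of a local learning_rate name. A minimal sketch of the resulting call, assuming a hypothetical create_optimizer factory wrapping a standard torch optimizer and a stand-in for the token embedding:

import torch

def create_optimizer(parameters, lr):
    # Hypothetical stand-in for the factory used in train_ti.py.
    return torch.optim.AdamW(parameters, lr=lr)

# Stand-in for text_encoder.text_model.embeddings.token_embedding;
# the dimensions here are illustrative only.
token_embedding = torch.nn.Embedding(49408, 768)

learning_rate = 1e-4  # would come from argparse as args.learning_rate
optimizer = create_optimizer(token_embedding.parameters(), lr=learning_rate)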