author     Volpeon <git@volpeon.ink>    2023-04-08 17:38:49 +0200
committer  Volpeon <git@volpeon.ink>    2023-04-08 17:38:49 +0200
commit     9f5f70cb2a8919cb07821f264bf0fd75bfa10584 (patch)
tree       19bd8802b6cfd941797beabfc0bb2595ffb00b5f /train_lora.py
parent     Fix TI (diff)
Update
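
Note on the change below: the commit adds a --filter_tokens option to parse_args() that falls back to the placeholder tokens when omitted. A minimal, self-contained sketch of that fallback (the "<embedding>" token name is only illustrative; the option definitions and fallback logic mirror the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--placeholder_tokens", type=str, nargs='*', default=[])
parser.add_argument("--filter_tokens", type=str, nargs='*',
                    help="Tokens to filter the dataset by.")
args = parser.parse_args(["--placeholder_tokens", "<embedding>"])

# Fall back to the placeholder tokens when --filter_tokens is omitted,
# as parse_args() does in this commit.
if args.filter_tokens is None:
    args.filter_tokens = args.placeholder_tokens.copy()

# Normalize a single string into a list, also matching the diff.
if isinstance(args.filter_tokens, str):
    args.filter_tokens = [args.filter_tokens]

print(args.filter_tokens)  # ['<embedding>']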
Diffstat (limited to 'train_lora.py')
-rw-r--r--    train_lora.py    32
1 file changed, 22 insertions(+), 10 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 1626be6..e4b5546 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -93,6 +93,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
+    parser.add_argument(
         "--initializer_noise",
         type=float,
         default=0,
@@ -592,6 +598,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -890,7 +902,7 @@ def main():
 
     pti_datamodule = create_datamodule(
         batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
     )
     pti_datamodule.setup()
 
@@ -906,7 +918,7 @@ def main():
     pti_optimizer = create_optimizer(
         [
             {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
                 "lr": args.learning_rate_pti,
                 "weight_decay": 0,
             },
@@ -937,7 +949,7 @@ def main():
         sample_frequency=pti_sample_frequency,
     )
 
-    # embeddings.persist()
+    embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -962,13 +974,13 @@ def main():
 
     params_to_optimize = []
     group_labels = []
-    if len(args.placeholder_tokens) != 0:
-        params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-            "lr": args.learning_rate_text,
-            "weight_decay": 0,
-        })
-        group_labels.append("emb")
+    # if len(args.placeholder_tokens) != 0:
+    #     params_to_optimize.append({
+    #         "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+    #         "lr": args.learning_rate_text,
+    #         "weight_decay": 0,
+    #     })
+    #     group_labels.append("emb")
     params_to_optimize += [
         {
             "params": (