diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-08 17:38:49 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-08 17:38:49 +0200 |
| commit | 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 (patch) | |
| tree | 19bd8802b6cfd941797beabfc0bb2595ffb00b5f /train_lora.py | |
| parent | Fix TI (diff) | |
| download | textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.tar.gz textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.tar.bz2 textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.zip | |
Update
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 32 |
1 files changed, 22 insertions, 10 deletions
diff --git a/train_lora.py b/train_lora.py index 1626be6..e4b5546 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -93,6 +93,12 @@ def parse_args(): | |||
| 93 | help="A token to use as initializer word." | 93 | help="A token to use as initializer word." |
| 94 | ) | 94 | ) |
| 95 | parser.add_argument( | 95 | parser.add_argument( |
| 96 | "--filter_tokens", | ||
| 97 | type=str, | ||
| 98 | nargs='*', | ||
| 99 | help="Tokens to filter the dataset by." | ||
| 100 | ) | ||
| 101 | parser.add_argument( | ||
| 96 | "--initializer_noise", | 102 | "--initializer_noise", |
| 97 | type=float, | 103 | type=float, |
| 98 | default=0, | 104 | default=0, |
| @@ -592,6 +598,12 @@ def parse_args(): | |||
| 592 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | 598 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: |
| 593 | raise ValueError("--alias_tokens must be a list with an even number of items") | 599 | raise ValueError("--alias_tokens must be a list with an even number of items") |
| 594 | 600 | ||
| 601 | if args.filter_tokens is None: | ||
| 602 | args.filter_tokens = args.placeholder_tokens.copy() | ||
| 603 | |||
| 604 | if isinstance(args.filter_tokens, str): | ||
| 605 | args.filter_tokens = [args.filter_tokens] | ||
| 606 | |||
| 595 | if isinstance(args.collection, str): | 607 | if isinstance(args.collection, str): |
| 596 | args.collection = [args.collection] | 608 | args.collection = [args.collection] |
| 597 | 609 | ||
| @@ -890,7 +902,7 @@ def main(): | |||
| 890 | 902 | ||
| 891 | pti_datamodule = create_datamodule( | 903 | pti_datamodule = create_datamodule( |
| 892 | batch_size=args.pti_batch_size, | 904 | batch_size=args.pti_batch_size, |
| 893 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | 905 | filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections), |
| 894 | ) | 906 | ) |
| 895 | pti_datamodule.setup() | 907 | pti_datamodule.setup() |
| 896 | 908 | ||
| @@ -906,7 +918,7 @@ def main(): | |||
| 906 | pti_optimizer = create_optimizer( | 918 | pti_optimizer = create_optimizer( |
| 907 | [ | 919 | [ |
| 908 | { | 920 | { |
| 909 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 921 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
| 910 | "lr": args.learning_rate_pti, | 922 | "lr": args.learning_rate_pti, |
| 911 | "weight_decay": 0, | 923 | "weight_decay": 0, |
| 912 | }, | 924 | }, |
| @@ -937,7 +949,7 @@ def main(): | |||
| 937 | sample_frequency=pti_sample_frequency, | 949 | sample_frequency=pti_sample_frequency, |
| 938 | ) | 950 | ) |
| 939 | 951 | ||
| 940 | # embeddings.persist() | 952 | embeddings.persist() |
| 941 | 953 | ||
| 942 | # LORA | 954 | # LORA |
| 943 | # -------------------------------------------------------------------------------- | 955 | # -------------------------------------------------------------------------------- |
| @@ -962,13 +974,13 @@ def main(): | |||
| 962 | 974 | ||
| 963 | params_to_optimize = [] | 975 | params_to_optimize = [] |
| 964 | group_labels = [] | 976 | group_labels = [] |
| 965 | if len(args.placeholder_tokens) != 0: | 977 | # if len(args.placeholder_tokens) != 0: |
| 966 | params_to_optimize.append({ | 978 | # params_to_optimize.append({ |
| 967 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 979 | # "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
| 968 | "lr": args.learning_rate_text, | 980 | # "lr": args.learning_rate_text, |
| 969 | "weight_decay": 0, | 981 | # "weight_decay": 0, |
| 970 | }) | 982 | # }) |
| 971 | group_labels.append("emb") | 983 | # group_labels.append("emb") |
| 972 | params_to_optimize += [ | 984 | params_to_optimize += [ |
| 973 | { | 985 | { |
| 974 | "params": ( | 986 | "params": ( |
