author     Volpeon <git@volpeon.ink>  2023-04-09 09:13:24 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-09 09:13:24 +0200
commit     810e9b3efeb99e76170486bdbb0f33a67e265dee (patch)
tree       5f68df15f78a748142a835a08ef635e20fa67b03
parent     Update (diff)
Made Lora script interactive
 train_lora.py | 111
 1 file changed, 71 insertions(+), 40 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index d8a4880..f1e7ec7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -159,6 +159,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--pti_tag_dropout",
+        type=float,
+        default=0,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
         help="Shuffle tags.",
@@ -891,7 +897,6 @@ def main():
         progressive_buckets=args.progressive_buckets,
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -919,12 +924,9 @@
     # --------------------------------------------------------------------------------

     if len(args.placeholder_tokens) != 0:
-        pti_output_dir = output_dir / "pti"
-        pti_checkpoint_output_dir = pti_output_dir / "model"
-        pti_sample_output_dir = pti_output_dir / "samples"
-
         pti_datamodule = create_datamodule(
             batch_size=args.pti_batch_size,
+            dropout=args.pti_tag_dropout,
             filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
         )
         pti_datamodule.setup()
@@ -955,22 +957,38 @@
             train_epochs=num_pti_epochs,
         )

-        trainer(
-            strategy=lora_strategy,
-            pti_mode=True,
-            project="pti",
-            train_dataloader=pti_datamodule.train_dataloader,
-            val_dataloader=pti_datamodule.val_dataloader,
-            optimizer=pti_optimizer,
-            lr_scheduler=pti_lr_scheduler,
-            num_train_epochs=num_pti_epochs,
-            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=pti_sample_output_dir,
-            checkpoint_output_dir=pti_checkpoint_output_dir,
-            sample_frequency=pti_sample_frequency,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print("")
+            print(f"============ PTI cycle {training_iter} ============")
+            print("")
+
+            pti_output_dir = output_dir / f"pti_{training_iter}"
+            pti_checkpoint_output_dir = pti_output_dir / "model"
+            pti_sample_output_dir = pti_output_dir / "samples"
+
+            trainer(
+                strategy=lora_strategy,
+                pti_mode=True,
+                project="pti",
+                train_dataloader=pti_datamodule.train_dataloader,
+                val_dataloader=pti_datamodule.val_dataloader,
+                optimizer=pti_optimizer,
+                lr_scheduler=pti_lr_scheduler,
+                num_train_epochs=num_pti_epochs,
+                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+                # --
+                group_labels=["emb"],
+                sample_output_dir=pti_sample_output_dir,
+                checkpoint_output_dir=pti_checkpoint_output_dir,
+                sample_frequency=pti_sample_frequency,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1

     if not args.train_emb:
         embeddings.persist()
@@ -978,12 +996,9 @@
     # LORA
     # --------------------------------------------------------------------------------

-    lora_output_dir = output_dir / "lora"
-    lora_checkpoint_output_dir = lora_output_dir / "model"
-    lora_sample_output_dir = lora_output_dir / "samples"
-
     lora_datamodule = create_datamodule(
         batch_size=args.train_batch_size,
+        dropout=args.tag_dropout,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
     )
     lora_datamodule.setup()
@@ -1037,21 +1052,37 @@
         train_epochs=num_train_epochs,
     )

-    trainer(
-        strategy=lora_strategy,
-        project="lora",
-        train_dataloader=lora_datamodule.train_dataloader,
-        val_dataloader=lora_datamodule.val_dataloader,
-        optimizer=lora_optimizer,
-        lr_scheduler=lora_lr_scheduler,
-        num_train_epochs=num_train_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        # --
-        group_labels=group_labels,
-        sample_output_dir=lora_sample_output_dir,
-        checkpoint_output_dir=lora_checkpoint_output_dir,
-        sample_frequency=lora_sample_frequency,
-    )
+    continue_training = True
+    training_iter = 1
+
+    while continue_training:
+        print("")
+        print(f"============ LoRA cycle {training_iter} ============")
+        print("")
+
+        lora_output_dir = output_dir / f"lora_{training_iter}"
+        lora_checkpoint_output_dir = lora_output_dir / "model"
+        lora_sample_output_dir = lora_output_dir / "samples"
+
+        trainer(
+            strategy=lora_strategy,
+            project=f"lora_{training_iter}",
+            train_dataloader=lora_datamodule.train_dataloader,
+            val_dataloader=lora_datamodule.val_dataloader,
+            optimizer=lora_optimizer,
+            lr_scheduler=lora_lr_scheduler,
+            num_train_epochs=num_train_epochs,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            # --
+            group_labels=group_labels,
+            sample_output_dir=lora_sample_output_dir,
+            checkpoint_output_dir=lora_checkpoint_output_dir,
+            sample_frequency=lora_sample_frequency,
+        )
+
+        response = input("Run another cycle? [y/n] ")
+        continue_training = response.lower().strip() != "n"
+        training_iter += 1


 if __name__ == "__main__":
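
For context, this commit wraps each training phase (PTI and LoRA) in an interactive loop: every cycle writes to its own numbered output directory (pti_1, pti_2, ... and lora_1, lora_2, ...) and the script stops only when the user answers "n" at the prompt; any other answer, including an empty one, starts another cycle. A minimal, self-contained sketch of that pattern follows; run_one_cycle is a hypothetical stand-in for the trainer(...) call in train_lora.py and is not part of the actual script.

from pathlib import Path
from typing import Callable


def run_interactive_cycles(
    output_dir: Path,
    prefix: str,
    run_one_cycle: Callable[[Path, Path], None],
) -> None:
    # Repeat one training phase until the user declines another cycle.
    continue_training = True
    training_iter = 1

    while continue_training:
        print("")
        print(f"============ {prefix} cycle {training_iter} ============")
        print("")

        # Each cycle gets its own numbered output directory, mirroring the
        # pti_{n} / lora_{n} layout introduced in the diff above.
        cycle_dir = output_dir / f"{prefix}_{training_iter}"
        run_one_cycle(cycle_dir / "model", cycle_dir / "samples")

        # Only an explicit "n" stops the loop; anything else continues.
        response = input("Run another cycle? [y/n] ")
        continue_training = response.lower().strip() != "n"
        training_iter += 1


# Hypothetical usage (the callback would wrap the real trainer(...) call):
# run_interactive_cycles(Path("output"), "lora",
#                        lambda model_dir, sample_dir: print(model_dir, sample_dir))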
