diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 37 |
1 file changed, 20 insertions, 17 deletions
diff --git a/train_lora.py b/train_lora.py
index 476efcf..5b0a292 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -21,7 +21,6 @@ from data.csv import VlpnDataModule, keyword_filter | |||
21 | from training.functional import train, add_placeholder_tokens, get_models | 21 | from training.functional import train, add_placeholder_tokens, get_models |
22 | from training.lr import plot_metrics | 22 | from training.lr import plot_metrics |
23 | from training.strategy.lora import lora_strategy | 23 | from training.strategy.lora import lora_strategy |
24 | from training.strategy.ti import textual_inversion_strategy | ||
25 | from training.optimization import get_scheduler | 24 | from training.optimization import get_scheduler |
26 | from training.util import save_args | 25 | from training.util import save_args |
27 | 26 | ||
@@ -829,6 +828,12 @@ def main(): | |||
829 | sample_num_batches=args.sample_batches, | 828 | sample_num_batches=args.sample_batches, |
830 | sample_num_steps=args.sample_steps, | 829 | sample_num_steps=args.sample_steps, |
831 | sample_image_size=args.sample_image_size, | 830 | sample_image_size=args.sample_image_size, |
831 | placeholder_tokens=args.placeholder_tokens, | ||
832 | placeholder_token_ids=placeholder_token_ids, | ||
833 | use_emb_decay=args.use_emb_decay, | ||
834 | emb_decay_target=args.emb_decay_target, | ||
835 | emb_decay=args.emb_decay, | ||
836 | max_grad_norm=args.max_grad_norm, | ||
832 | ) | 837 | ) |
833 | 838 | ||
834 | create_datamodule = partial( | 839 | create_datamodule = partial( |
@@ -907,7 +912,8 @@ def main(): | |||
907 | ) | 912 | ) |
908 | 913 | ||
909 | metrics = trainer( | 914 | metrics = trainer( |
910 | strategy=textual_inversion_strategy, | 915 | strategy=lora_strategy, |
916 | pti_mode=True, | ||
911 | project="pti", | 917 | project="pti", |
912 | train_dataloader=pti_datamodule.train_dataloader, | 918 | train_dataloader=pti_datamodule.train_dataloader, |
913 | val_dataloader=pti_datamodule.val_dataloader, | 919 | val_dataloader=pti_datamodule.val_dataloader, |
@@ -919,11 +925,6 @@ def main(): | |||
919 | sample_output_dir=pti_sample_output_dir, | 925 | sample_output_dir=pti_sample_output_dir, |
920 | checkpoint_output_dir=pti_checkpoint_output_dir, | 926 | checkpoint_output_dir=pti_checkpoint_output_dir, |
921 | sample_frequency=math.inf, | 927 | sample_frequency=math.inf, |
922 | placeholder_tokens=args.placeholder_tokens, | ||
923 | placeholder_token_ids=placeholder_token_ids, | ||
924 | use_emb_decay=args.use_emb_decay, | ||
925 | emb_decay_target=args.emb_decay_target, | ||
926 | emb_decay=args.emb_decay, | ||
927 | ) | 928 | ) |
928 | 929 | ||
929 | plot_metrics(metrics, pti_output_dir / "lr.png") | 930 | plot_metrics(metrics, pti_output_dir / "lr.png") |
@@ -952,13 +953,21 @@ def main(): | |||
952 | lora_optimizer = create_optimizer( | 953 | lora_optimizer = create_optimizer( |
953 | [ | 954 | [ |
954 | { | 955 | { |
955 | "params": unet.parameters(), | 956 | "params": ( |
957 | param | ||
958 | for param in unet.parameters() | ||
959 | if param.requires_grad | ||
960 | ), | ||
956 | "lr": args.learning_rate_unet, | 961 | "lr": args.learning_rate_unet, |
957 | }, | 962 | }, |
958 | { | 963 | { |
959 | "params": itertools.chain( | 964 | "params": ( |
960 | text_encoder.text_model.encoder.parameters(), | 965 | param |
961 | text_encoder.text_model.final_layer_norm.parameters(), | 966 | for param in itertools.chain( |
967 | text_encoder.text_model.encoder.parameters(), | ||
968 | text_encoder.text_model.final_layer_norm.parameters(), | ||
969 | ) | ||
970 | if param.requires_grad | ||
962 | ), | 971 | ), |
963 | "lr": args.learning_rate_text, | 972 | "lr": args.learning_rate_text, |
964 | }, | 973 | }, |
@@ -990,12 +999,6 @@ def main(): | |||
990 | sample_output_dir=lora_sample_output_dir, | 999 | sample_output_dir=lora_sample_output_dir, |
991 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1000 | checkpoint_output_dir=lora_checkpoint_output_dir, |
992 | sample_frequency=lora_sample_frequency, | 1001 | sample_frequency=lora_sample_frequency, |
993 | placeholder_tokens=args.placeholder_tokens, | ||
994 | placeholder_token_ids=placeholder_token_ids, | ||
995 | use_emb_decay=args.use_emb_decay, | ||
996 | emb_decay_target=args.emb_decay_target, | ||
997 | emb_decay=args.emb_decay, | ||
998 | max_grad_norm=args.max_grad_norm, | ||
999 | ) | 1002 | ) |
1000 | 1003 | ||
1001 | plot_metrics(metrics, lora_output_dir / "lr.png") | 1004 | plot_metrics(metrics, lora_output_dir / "lr.png") |