| field | value |
|---|---|
| author | Volpeon <git@volpeon.ink>, 2023-04-07 11:31:21 +0200 |
| committer | Volpeon <git@volpeon.ink>, 2023-04-07 11:31:21 +0200 |
| commit | 37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f (patch) |
| tree | 1f18d01cc23418789b6b4b00b38edc0a80b6214a /train_lora.py |
| parent | Fix (diff) |
| download | textual-inversion-diff-37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f.tar.gz, textual-inversion-diff-37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f.tar.bz2, textual-inversion-diff-37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f.zip |
Run PTI only if placeholder tokens arg isn't empty
Diffstat (limited to 'train_lora.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_lora.py | 103 |

1 file changed, 52 insertions, 51 deletions
```diff
diff --git a/train_lora.py b/train_lora.py
index 6de3a75..daf1f6c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -867,62 +867,63 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    pti_output_dir = output_dir / "pti"
-    pti_checkpoint_output_dir = pti_output_dir / "model"
-    pti_sample_output_dir = pti_output_dir / "samples"
+    if len(args.placeholder_tokens) != 0:
+        pti_output_dir = output_dir / "pti"
+        pti_checkpoint_output_dir = pti_output_dir / "model"
+        pti_sample_output_dir = pti_output_dir / "samples"
 
-    pti_datamodule = create_datamodule(
-        batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-    )
-    pti_datamodule.setup()
+        pti_datamodule = create_datamodule(
+            batch_size=args.pti_batch_size,
+            filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
 
-    num_pti_epochs = args.num_pti_epochs
-    pti_sample_frequency = args.sample_frequency
-    if num_pti_epochs is None:
-        num_pti_epochs = math.ceil(
-            args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+        num_pti_epochs = args.num_pti_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_pti_epochs is None:
+            num_pti_epochs = math.ceil(
+                args.num_pti_steps / len(pti_datamodule.train_dataset)
+            ) * args.pti_gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
 
-    pti_optimizer = create_optimizer(
-        [
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_pti,
-                "weight_decay": 0,
-            },
-        ]
-    )
+        pti_optimizer = create_optimizer(
+            [
+                {
+                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                    "lr": args.learning_rate_pti,
+                    "weight_decay": 0,
+                },
+            ]
+        )
 
-    pti_lr_scheduler = create_lr_scheduler(
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        optimizer=pti_optimizer,
-        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-        train_epochs=num_pti_epochs,
-    )
+        pti_lr_scheduler = create_lr_scheduler(
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_pti_epochs,
+        )
 
-    metrics = trainer(
-        strategy=textual_inversion_strategy,
-        project="pti",
-        train_dataloader=pti_datamodule.train_dataloader,
-        val_dataloader=pti_datamodule.val_dataloader,
-        optimizer=pti_optimizer,
-        lr_scheduler=pti_lr_scheduler,
-        num_train_epochs=num_pti_epochs,
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        # --
-        sample_output_dir=pti_sample_output_dir,
-        checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-    )
+        metrics = trainer(
+            strategy=textual_inversion_strategy,
+            project="pti",
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_pti_epochs,
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            # --
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+            placeholder_tokens=args.placeholder_tokens,
+            placeholder_token_ids=placeholder_token_ids,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay=args.emb_decay,
+        )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+        plot_metrics(metrics, pti_output_dir / "lr.png")
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -994,7 +995,7 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+    plot_metrics(metrics, lora_output_dir / "lr.png")
 
 
 if __name__ == "__main__":
```
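For context, the sketch below shows the control flow this commit produces: the PTI phase runs only when placeholder tokens were requested, and LoRA training runs either way. It is a minimal, self-contained illustration; `run_pti`, `run_lora`, and the CLI flags shown here are hypothetical stand-ins, not the actual interface of `train_lora.py`.

```python
# Minimal sketch of the guard introduced by this commit (illustrative only).
import argparse
from pathlib import Path


def run_pti(args: argparse.Namespace, output_dir: Path) -> None:
    # Stand-in for the PTI block in main(): datamodule, optimizer, scheduler, trainer.
    print(f"PTI phase for tokens {args.placeholder_tokens} -> {output_dir / 'pti'}")


def run_lora(args: argparse.Namespace, output_dir: Path) -> None:
    # Stand-in for the LoRA block in main(), which is unaffected by the guard.
    print(f"LoRA phase -> {output_dir / 'lora'}")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--placeholder_tokens", nargs="*", default=[])
    parser.add_argument("--output_dir", type=Path, default=Path("output"))
    args = parser.parse_args()

    # The guard added by this commit: skip PTI entirely when no tokens are given.
    if len(args.placeholder_tokens) != 0:
        run_pti(args, args.output_dir)

    run_lora(args, args.output_dir)


if __name__ == "__main__":
    main()
```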
