author     Volpeon <git@volpeon.ink>  2023-04-07 17:42:23 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-07 17:42:23 +0200
commit     4ca5d3bf8710c479dab994a7243b2dd9232bab3d (patch)
tree       2a64e092bf0e8f2a40aa6a1ef8c5f1eecb6f9708 /train_lora.py
parent     Fix (diff)
download   textual-inversion-diff-4ca5d3bf8710c479dab994a7243b2dd9232bab3d.tar.gz
           textual-inversion-diff-4ca5d3bf8710c479dab994a7243b2dd9232bab3d.tar.bz2
           textual-inversion-diff-4ca5d3bf8710c479dab994a7243b2dd9232bab3d.zip
Fixed Lora PTI
Diffstat (limited to 'train_lora.py')
 train_lora.py | 37 ++++++++++++++++++++-----------------
 1 file changed, 20 insertions(+), 17 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 476efcf..5b0a292 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -21,7 +21,6 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
-from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
@@ -829,6 +828,12 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        max_grad_norm=args.max_grad_norm,
     )
 
     create_datamodule = partial(
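The embedding-decay and gradient-clipping options now live in the create_trainer partial, so both training phases inherit them instead of repeating them at each call site; the duplicated keyword arguments are dropped from the PTI and LoRA calls further down. As a rough illustration of what such a decay could do, here is a minimal sketch, assuming use_emb_decay nudges each placeholder embedding's L2 norm toward emb_decay_target after an optimizer step; the function name and exact update rule are illustrative, not taken from this repository:

import torch

def decay_embeddings(
    embedding: torch.nn.Embedding,
    placeholder_token_ids: list[int],
    emb_decay_target: float,
    emb_decay: float,
) -> None:
    # Illustrative rule: pull each trained embedding's L2 norm toward
    # the target while keeping its direction (assumes non-zero vectors).
    with torch.no_grad():
        for token_id in placeholder_token_ids:
            vec = embedding.weight[token_id]
            norm = vec.norm()
            scale = (norm + emb_decay * (emb_decay_target - norm)) / norm
            embedding.weight[token_id] = vec * scale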
@@ -907,7 +912,8 @@ def main():
     )
 
     metrics = trainer(
-        strategy=textual_inversion_strategy,
+        strategy=lora_strategy,
+        pti_mode=True,
         project="pti",
         train_dataloader=pti_datamodule.train_dataloader,
         val_dataloader=pti_datamodule.val_dataloader,
@@ -919,11 +925,6 @@ def main():
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=math.inf,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
     )
 
     plot_metrics(metrics, pti_output_dir / "lr.png")
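With this change the PTI phase runs through lora_strategy in a dedicated pti_mode rather than through the removed textual_inversion_strategy, which is what lets the embedding arguments move into create_trainer above. The repository's strategy code is not shown on this page; as a hypothetical sketch of how one strategy could serve both phases, a pti_mode flag might select which parameters receive gradients (all names below are illustrative):

import itertools
import torch

def trainable_parameters(
    unet: torch.nn.Module,
    text_encoder: torch.nn.Module,
    token_embedding: torch.nn.Embedding,
    pti_mode: bool,
) -> list[torch.nn.Parameter]:
    if pti_mode:
        # Hypothetical PTI phase: freeze everything except the
        # placeholder token embeddings.
        for p in itertools.chain(unet.parameters(), text_encoder.parameters()):
            p.requires_grad_(False)
        token_embedding.weight.requires_grad_(True)
        return [token_embedding.weight]
    # LoRA phase: train whatever the LoRA setup left unfrozen.
    return [
        p
        for p in itertools.chain(unet.parameters(), text_encoder.parameters())
        if p.requires_grad
    ]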
@@ -952,13 +953,21 @@ def main():
     lora_optimizer = create_optimizer(
         [
             {
-                "params": unet.parameters(),
+                "params": (
+                    param
+                    for param in unet.parameters()
+                    if param.requires_grad
+                ),
                 "lr": args.learning_rate_unet,
             },
             {
-                "params": itertools.chain(
-                    text_encoder.text_model.encoder.parameters(),
-                    text_encoder.text_model.final_layer_norm.parameters(),
+                "params": (
+                    param
+                    for param in itertools.chain(
+                        text_encoder.text_model.encoder.parameters(),
+                        text_encoder.text_model.final_layer_norm.parameters(),
+                    )
+                    if param.requires_grad
                 ),
                 "lr": args.learning_rate_text,
             },
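The optimizer change filters each parameter group down to parameters with requires_grad set. With LoRA, the base UNet and text-encoder weights are frozen and only the injected adapter weights train; handing frozen tensors to the optimizer anyway would allocate optimizer state for them for no benefit. A self-contained sketch of the same per-group filtering, using stand-in modules and hypothetical learning rates:

import itertools
import torch

# Stand-ins for the real models; the frozen weight plays the role of
# a base layer that LoRA leaves untouched.
unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)
text_encoder.weight.requires_grad_(False)

optimizer = torch.optim.AdamW(
    [
        {
            "params": (p for p in unet.parameters() if p.requires_grad),
            "lr": 1e-4,
        },
        {
            "params": (
                p
                for p in itertools.chain(text_encoder.parameters())
                if p.requires_grad
            ),
            "lr": 5e-5,
        },
    ]
)

# Only the still-trainable bias of text_encoder lands in the second group.
assert len(optimizer.param_groups[1]["params"]) == 1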
@@ -990,12 +999,6 @@ def main():
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-        max_grad_norm=args.max_grad_norm,
     )
 
     plot_metrics(metrics, lora_output_dir / "lr.png")