summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-07 21:41:53 +0200
committerVolpeon <git@volpeon.ink>2023-04-07 21:41:53 +0200
commit85953e46c0d36658293b1cd39e26f5f550b173f8 (patch)
tree425cf3b8ad2bdadd1c66ea054e779d20b2138457
parentFixed Lora PTI (diff)
downloadtextual-inversion-diff-85953e46c0d36658293b1cd39e26f5f550b173f8.tar.gz
textual-inversion-diff-85953e46c0d36658293b1cd39e26f5f550b173f8.tar.bz2
textual-inversion-diff-85953e46c0d36658293b1cd39e26f5f550b173f8.zip
Fix
-rw-r--r--train_lora.py77
1 files changed, 42 insertions, 35 deletions
diff --git a/train_lora.py b/train_lora.py
index 5b0a292..9f17495 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -445,6 +445,12 @@ def parse_args():
445 help="How often to save a checkpoint and sample image", 445 help="How often to save a checkpoint and sample image",
446 ) 446 )
447 parser.add_argument( 447 parser.add_argument(
448 "--pti_sample_frequency",
449 type=int,
450 default=1,
451 help="How often to save a checkpoint and sample image",
452 )
453 parser.add_argument(
448 "--sample_image_size", 454 "--sample_image_size",
449 type=int, 455 type=int,
450 default=768, 456 default=768,
@@ -887,47 +893,48 @@ def main():
887 pti_datamodule.setup() 893 pti_datamodule.setup()
888 894
889 num_pti_epochs = args.num_pti_epochs 895 num_pti_epochs = args.num_pti_epochs
890 pti_sample_frequency = args.sample_frequency 896 pti_sample_frequency = args.pti_sample_frequency
891 if num_pti_epochs is None: 897 if num_pti_epochs is None:
892 num_pti_epochs = math.ceil( 898 num_pti_epochs = math.ceil(
893 args.num_pti_steps / len(pti_datamodule.train_dataset) 899 args.num_pti_steps / len(pti_datamodule.train_dataset)
894 ) * args.pti_gradient_accumulation_steps 900 ) * args.pti_gradient_accumulation_steps
895 pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps)) 901 pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
896 902
897 pti_optimizer = create_optimizer( 903 if num_pti_epochs > 0:
898 [ 904 pti_optimizer = create_optimizer(
899 { 905 [
900 "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), 906 {
901 "lr": args.learning_rate_pti, 907 "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
902 "weight_decay": 0, 908 "lr": args.learning_rate_pti,
903 }, 909 "weight_decay": 0,
904 ] 910 },
905 ) 911 ]
906 912 )
907 pti_lr_scheduler = create_lr_scheduler( 913
908 gradient_accumulation_steps=args.pti_gradient_accumulation_steps, 914 pti_lr_scheduler = create_lr_scheduler(
909 optimizer=pti_optimizer, 915 gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
910 num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), 916 optimizer=pti_optimizer,
911 train_epochs=num_pti_epochs, 917 num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
912 ) 918 train_epochs=num_pti_epochs,
913 919 )
914 metrics = trainer( 920
915 strategy=lora_strategy, 921 metrics = trainer(
916 pti_mode=True, 922 strategy=lora_strategy,
917 project="pti", 923 pti_mode=True,
918 train_dataloader=pti_datamodule.train_dataloader, 924 project="pti",
919 val_dataloader=pti_datamodule.val_dataloader, 925 train_dataloader=pti_datamodule.train_dataloader,
920 optimizer=pti_optimizer, 926 val_dataloader=pti_datamodule.val_dataloader,
921 lr_scheduler=pti_lr_scheduler, 927 optimizer=pti_optimizer,
922 num_train_epochs=num_pti_epochs, 928 lr_scheduler=pti_lr_scheduler,
923 gradient_accumulation_steps=args.pti_gradient_accumulation_steps, 929 num_train_epochs=num_pti_epochs,
924 # -- 930 gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
925 sample_output_dir=pti_sample_output_dir, 931 # --
926 checkpoint_output_dir=pti_checkpoint_output_dir, 932 sample_output_dir=pti_sample_output_dir,
927 sample_frequency=math.inf, 933 checkpoint_output_dir=pti_checkpoint_output_dir,
928 ) 934 sample_frequency=pti_sample_frequency,
929 935 )
930 plot_metrics(metrics, pti_output_dir / "lr.png") 936
937 plot_metrics(metrics, pti_output_dir / "lr.png")
931 938
932 # LORA 939 # LORA
933 # -------------------------------------------------------------------------------- 940 # --------------------------------------------------------------------------------