diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 77 |
1 files changed, 42 insertions, 35 deletions
diff --git a/train_lora.py b/train_lora.py index 5b0a292..9f17495 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -445,6 +445,12 @@ def parse_args(): | |||
445 | help="How often to save a checkpoint and sample image", | 445 | help="How often to save a checkpoint and sample image", |
446 | ) | 446 | ) |
447 | parser.add_argument( | 447 | parser.add_argument( |
448 | "--pti_sample_frequency", | ||
449 | type=int, | ||
450 | default=1, | ||
451 | help="How often to save a checkpoint and sample image", | ||
452 | ) | ||
453 | parser.add_argument( | ||
448 | "--sample_image_size", | 454 | "--sample_image_size", |
449 | type=int, | 455 | type=int, |
450 | default=768, | 456 | default=768, |
@@ -887,47 +893,48 @@ def main(): | |||
887 | pti_datamodule.setup() | 893 | pti_datamodule.setup() |
888 | 894 | ||
889 | num_pti_epochs = args.num_pti_epochs | 895 | num_pti_epochs = args.num_pti_epochs |
890 | pti_sample_frequency = args.sample_frequency | 896 | pti_sample_frequency = args.pti_sample_frequency |
891 | if num_pti_epochs is None: | 897 | if num_pti_epochs is None: |
892 | num_pti_epochs = math.ceil( | 898 | num_pti_epochs = math.ceil( |
893 | args.num_pti_steps / len(pti_datamodule.train_dataset) | 899 | args.num_pti_steps / len(pti_datamodule.train_dataset) |
894 | ) * args.pti_gradient_accumulation_steps | 900 | ) * args.pti_gradient_accumulation_steps |
895 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps)) | 901 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps)) |
896 | 902 | ||
897 | pti_optimizer = create_optimizer( | 903 | if num_pti_epochs > 0: |
898 | [ | 904 | pti_optimizer = create_optimizer( |
899 | { | 905 | [ |
900 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 906 | { |
901 | "lr": args.learning_rate_pti, | 907 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
902 | "weight_decay": 0, | 908 | "lr": args.learning_rate_pti, |
903 | }, | 909 | "weight_decay": 0, |
904 | ] | 910 | }, |
905 | ) | 911 | ] |
906 | 912 | ) | |
907 | pti_lr_scheduler = create_lr_scheduler( | 913 | |
908 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, | 914 | pti_lr_scheduler = create_lr_scheduler( |
909 | optimizer=pti_optimizer, | 915 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, |
910 | num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), | 916 | optimizer=pti_optimizer, |
911 | train_epochs=num_pti_epochs, | 917 | num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), |
912 | ) | 918 | train_epochs=num_pti_epochs, |
913 | 919 | ) | |
914 | metrics = trainer( | 920 | |
915 | strategy=lora_strategy, | 921 | metrics = trainer( |
916 | pti_mode=True, | 922 | strategy=lora_strategy, |
917 | project="pti", | 923 | pti_mode=True, |
918 | train_dataloader=pti_datamodule.train_dataloader, | 924 | project="pti", |
919 | val_dataloader=pti_datamodule.val_dataloader, | 925 | train_dataloader=pti_datamodule.train_dataloader, |
920 | optimizer=pti_optimizer, | 926 | val_dataloader=pti_datamodule.val_dataloader, |
921 | lr_scheduler=pti_lr_scheduler, | 927 | optimizer=pti_optimizer, |
922 | num_train_epochs=num_pti_epochs, | 928 | lr_scheduler=pti_lr_scheduler, |
923 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, | 929 | num_train_epochs=num_pti_epochs, |
924 | # -- | 930 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, |
925 | sample_output_dir=pti_sample_output_dir, | 931 | # -- |
926 | checkpoint_output_dir=pti_checkpoint_output_dir, | 932 | sample_output_dir=pti_sample_output_dir, |
927 | sample_frequency=math.inf, | 933 | checkpoint_output_dir=pti_checkpoint_output_dir, |
928 | ) | 934 | sample_frequency=pti_sample_frequency, |
929 | 935 | ) | |
930 | plot_metrics(metrics, pti_output_dir / "lr.png") | 936 | |
937 | plot_metrics(metrics, pti_output_dir / "lr.png") | ||
931 | 938 | ||
932 | # LORA | 939 | # LORA |
933 | # -------------------------------------------------------------------------------- | 940 | # -------------------------------------------------------------------------------- |