diff options
| -rw-r--r-- | train_lora.py | 71 |
1 files changed, 39 insertions, 32 deletions
diff --git a/train_lora.py b/train_lora.py index 5b0a292..9f17495 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -445,6 +445,12 @@ def parse_args(): | |||
| 445 | help="How often to save a checkpoint and sample image", | 445 | help="How often to save a checkpoint and sample image", |
| 446 | ) | 446 | ) |
| 447 | parser.add_argument( | 447 | parser.add_argument( |
| 448 | "--pti_sample_frequency", | ||
| 449 | type=int, | ||
| 450 | default=1, | ||
| 451 | help="How often to save a checkpoint and sample image", | ||
| 452 | ) | ||
| 453 | parser.add_argument( | ||
| 448 | "--sample_image_size", | 454 | "--sample_image_size", |
| 449 | type=int, | 455 | type=int, |
| 450 | default=768, | 456 | default=768, |
| @@ -887,47 +893,48 @@ def main(): | |||
| 887 | pti_datamodule.setup() | 893 | pti_datamodule.setup() |
| 888 | 894 | ||
| 889 | num_pti_epochs = args.num_pti_epochs | 895 | num_pti_epochs = args.num_pti_epochs |
| 890 | pti_sample_frequency = args.sample_frequency | 896 | pti_sample_frequency = args.pti_sample_frequency |
| 891 | if num_pti_epochs is None: | 897 | if num_pti_epochs is None: |
| 892 | num_pti_epochs = math.ceil( | 898 | num_pti_epochs = math.ceil( |
| 893 | args.num_pti_steps / len(pti_datamodule.train_dataset) | 899 | args.num_pti_steps / len(pti_datamodule.train_dataset) |
| 894 | ) * args.pti_gradient_accumulation_steps | 900 | ) * args.pti_gradient_accumulation_steps |
| 895 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps)) | 901 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps)) |
| 896 | 902 | ||
| 897 | pti_optimizer = create_optimizer( | 903 | if num_pti_epochs > 0: |
| 898 | [ | 904 | pti_optimizer = create_optimizer( |
| 899 | { | 905 | [ |
| 900 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 906 | { |
| 901 | "lr": args.learning_rate_pti, | 907 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
| 902 | "weight_decay": 0, | 908 | "lr": args.learning_rate_pti, |
| 903 | }, | 909 | "weight_decay": 0, |
| 904 | ] | 910 | }, |
| 905 | ) | 911 | ] |
| 912 | ) | ||
| 906 | 913 | ||
| 907 | pti_lr_scheduler = create_lr_scheduler( | 914 | pti_lr_scheduler = create_lr_scheduler( |
| 908 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, | 915 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, |
| 909 | optimizer=pti_optimizer, | 916 | optimizer=pti_optimizer, |
| 910 | num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), | 917 | num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), |
| 911 | train_epochs=num_pti_epochs, | 918 | train_epochs=num_pti_epochs, |
| 912 | ) | 919 | ) |
| 913 | 920 | ||
| 914 | metrics = trainer( | 921 | metrics = trainer( |
| 915 | strategy=lora_strategy, | 922 | strategy=lora_strategy, |
| 916 | pti_mode=True, | 923 | pti_mode=True, |
| 917 | project="pti", | 924 | project="pti", |
| 918 | train_dataloader=pti_datamodule.train_dataloader, | 925 | train_dataloader=pti_datamodule.train_dataloader, |
| 919 | val_dataloader=pti_datamodule.val_dataloader, | 926 | val_dataloader=pti_datamodule.val_dataloader, |
| 920 | optimizer=pti_optimizer, | 927 | optimizer=pti_optimizer, |
| 921 | lr_scheduler=pti_lr_scheduler, | 928 | lr_scheduler=pti_lr_scheduler, |
| 922 | num_train_epochs=num_pti_epochs, | 929 | num_train_epochs=num_pti_epochs, |
| 923 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, | 930 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, |
| 924 | # -- | 931 | # -- |
| 925 | sample_output_dir=pti_sample_output_dir, | 932 | sample_output_dir=pti_sample_output_dir, |
| 926 | checkpoint_output_dir=pti_checkpoint_output_dir, | 933 | checkpoint_output_dir=pti_checkpoint_output_dir, |
| 927 | sample_frequency=math.inf, | 934 | sample_frequency=pti_sample_frequency, |
| 928 | ) | 935 | ) |
| 929 | 936 | ||
| 930 | plot_metrics(metrics, pti_output_dir / "lr.png") | 937 | plot_metrics(metrics, pti_output_dir / "lr.png") |
| 931 | 938 | ||
| 932 | # LORA | 939 | # LORA |
| 933 | # -------------------------------------------------------------------------------- | 940 | # -------------------------------------------------------------------------------- |
