 train_lora.py | 103 ++++++++++++++++----------------
 train_ti.py   |  12 ++--
 2 files changed, 58 insertions(+), 57 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 6de3a75..daf1f6c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -867,62 +867,63 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    pti_output_dir = output_dir / "pti"
-    pti_checkpoint_output_dir = pti_output_dir / "model"
-    pti_sample_output_dir = pti_output_dir / "samples"
+    if len(args.placeholder_tokens) != 0:
+        pti_output_dir = output_dir / "pti"
+        pti_checkpoint_output_dir = pti_output_dir / "model"
+        pti_sample_output_dir = pti_output_dir / "samples"
 
-    pti_datamodule = create_datamodule(
-        batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-    )
-    pti_datamodule.setup()
+        pti_datamodule = create_datamodule(
+            batch_size=args.pti_batch_size,
+            filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
 
-    num_pti_epochs = args.num_pti_epochs
-    pti_sample_frequency = args.sample_frequency
-    if num_pti_epochs is None:
-        num_pti_epochs = math.ceil(
-            args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+        num_pti_epochs = args.num_pti_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_pti_epochs is None:
+            num_pti_epochs = math.ceil(
+                args.num_pti_steps / len(pti_datamodule.train_dataset)
+            ) * args.pti_gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
 
-    pti_optimizer = create_optimizer(
-        [
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_pti,
-                "weight_decay": 0,
-            },
-        ]
-    )
+        pti_optimizer = create_optimizer(
+            [
+                {
+                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                    "lr": args.learning_rate_pti,
+                    "weight_decay": 0,
+                },
+            ]
+        )
 
-    pti_lr_scheduler = create_lr_scheduler(
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        optimizer=pti_optimizer,
-        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-        train_epochs=num_pti_epochs,
-    )
+        pti_lr_scheduler = create_lr_scheduler(
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_pti_epochs,
+        )
 
-    metrics = trainer(
-        strategy=textual_inversion_strategy,
-        project="pti",
-        train_dataloader=pti_datamodule.train_dataloader,
-        val_dataloader=pti_datamodule.val_dataloader,
-        optimizer=pti_optimizer,
-        lr_scheduler=pti_lr_scheduler,
-        num_train_epochs=num_pti_epochs,
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        # --
-        sample_output_dir=pti_sample_output_dir,
-        checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-    )
+        metrics = trainer(
+            strategy=textual_inversion_strategy,
+            project="pti",
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_pti_epochs,
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            # --
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+            placeholder_tokens=args.placeholder_tokens,
+            placeholder_token_ids=placeholder_token_ids,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay=args.emb_decay,
+        )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+        plot_metrics(metrics, pti_output_dir / "lr.png")
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -994,7 +995,7 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+    plot_metrics(metrics, lora_output_dir / "lr.png")
 
 
 if __name__ == "__main__":
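Two things change in train_lora.py: the whole PTI phase is now guarded by `if len(args.placeholder_tokens) != 0:` (with no placeholder tokens there is nothing for pivotal tuning to optimize), and each phase writes its learning-rate plot into its own directory instead of both phases writing to output_dir / "lr.png", where the LoRA plot used to overwrite the PTI one. Below is a minimal runnable sketch of the new control flow, not the real script: plot_metrics is stubbed out, the metrics dicts stand in for the trainer() results, and lora_output_dir is defined outside the hunk shown above.

from pathlib import Path

def plot_metrics(metrics, path: Path) -> None:
    # Stub standing in for the real plotting helper in train_lora.py.
    print(f"plot {metrics} -> {path}")

def run_phases(placeholder_tokens: list, output_dir: Path) -> None:
    if len(placeholder_tokens) != 0:
        # PTI only runs when there are placeholder tokens to invert.
        pti_output_dir = output_dir / "pti"
        pti_metrics = {"lr": [1e-4]}  # stand-in for the trainer() result
        plot_metrics(pti_metrics, pti_output_dir / "lr.png")

    # The LoRA phase runs either way; its plot no longer clobbers the PTI one.
    lora_output_dir = output_dir / "lora"  # assumed layout; defined outside the hunk
    lora_metrics = {"lr": [1e-5]}
    plot_metrics(lora_metrics, lora_output_dir / "lr.png")

run_phases([], Path("output"))           # skips PTI entirely
run_phases(["<item>"], Path("output"))   # runs PTI, then LoRA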
diff --git a/train_ti.py b/train_ti.py
index 344b412..c1c0eed 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -602,7 +602,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
@@ -743,7 +743,7 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir/"checkpoints"
+    checkpoint_output_dir = output_dir / "checkpoints"
 
     trainer = partial(
         train,
@@ -782,11 +782,11 @@ def main():
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
         if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
+            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
+            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
-            sample_output_dir = output_dir/"samples"
-            metrics_output_file = output_dir/f"lr.png"
+            sample_output_dir = output_dir / "samples"
+            metrics_output_file = output_dir / "lr.png"
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
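The train_ti.py hunks are purely cosmetic: pathlib's overloaded / operator gets PEP 8 spacing around the binary operator, and a redundant f-prefix on a literal with no placeholders is dropped. A small demonstration of why behavior is unchanged, assuming output_dir is a pathlib.Path as in the script:

from pathlib import Path

output_dir = Path("output")

# Path.__truediv__ joins segments; spacing around / is purely cosmetic.
assert output_dir/"log.txt" == output_dir / "log.txt"

# An f-string with no {placeholders} is just a plain string,
# which is why f"lr.png" became "lr.png".
assert f"lr.png" == "lr.png"

# The prefix is only needed when the string actually interpolates:
placeholder_tokens = ["<token>"]
print(output_dir / f"{placeholder_tokens[0]}.png")  # output/<token>.png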
