Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  140
1 file changed, 12 insertions, 128 deletions
diff --git a/train_lora.py b/train_lora.py
index 8dbe45b..6e21634 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -159,12 +159,6 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
-        "--pti_tag_dropout",
-        type=float,
-        default=0,
-        help="Tag dropout probability.",
-    )
-    parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
         help="Shuffle tags.",
@@ -236,28 +230,12 @@ def parse_args():
         default=2000
     )
     parser.add_argument(
-        "--num_pti_epochs",
-        type=int,
-        default=None
-    )
-    parser.add_argument(
-        "--num_pti_steps",
-        type=int,
-        default=500
-    )
-    parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
-        "--pti_gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
-    parser.add_argument(
         "--lora_r",
         type=int,
         default=8,
@@ -323,12 +301,6 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--learning_rate_pti",
-        type=float,
-        default=1e-4,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
         "--learning_rate_emb",
         type=float,
         default=1e-5,
@@ -467,12 +439,6 @@ def parse_args():
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
-        "--pti_sample_frequency",
-        type=int,
-        default=1,
-        help="How often to save a checkpoint and sample image",
-    )
-    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=768,
@@ -509,12 +475,6 @@ def parse_args():
         help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--pti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
         "--sample_steps",
         type=int,
         default=10,
@@ -527,6 +487,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
+    parser.add_argument(
         "--use_emb_decay",
         action="store_true",
         help="Whether to use embedding decay."
@@ -674,7 +640,7 @@ def main():
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)

     unet_config = LoraConfig(
         r=args.lora_r,
@@ -720,6 +686,7 @@ def main():
         raise ValueError("--embeddings_dir must point to an existing directory")

     added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+    embeddings.persist()
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -744,19 +711,14 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
-        args.learning_rate_pti = (
-            args.learning_rate_pti * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
-        )
         args.learning_rate_emb = (
-            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
+            args.learning_rate_emb * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
         )

     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
-        args.learning_rate_pti = 1e-6
         args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"

@@ -817,7 +779,6 @@ def main():
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
         args.learning_rate_text = None
-        args.learning_rate_pti = None
         args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
@@ -836,7 +797,6 @@ def main():

         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
@@ -853,7 +813,6 @@ def main():

         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -920,80 +879,6 @@ def main():
         mid_point=args.lr_mid_point,
     )

-    # PTI
-    # --------------------------------------------------------------------------------
-
-    if len(args.placeholder_tokens) != 0:
-        pti_datamodule = create_datamodule(
-            batch_size=args.pti_batch_size,
-            dropout=args.pti_tag_dropout,
-            filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
-        )
-        pti_datamodule.setup()
-
-        num_pti_epochs = args.num_pti_epochs
-        pti_sample_frequency = args.pti_sample_frequency
-        if num_pti_epochs is None:
-            num_pti_epochs = math.ceil(
-                args.num_pti_steps / len(pti_datamodule.train_dataset)
-            ) * args.pti_gradient_accumulation_steps
-            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
-
-        if num_pti_epochs > 0:
-            pti_optimizer = create_optimizer(
-                [
-                    {
-                        "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-                        "lr": args.learning_rate_pti,
-                        "weight_decay": 0,
-                    },
-                ]
-            )
-
-            pti_lr_scheduler = create_lr_scheduler(
-                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                optimizer=pti_optimizer,
-                num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-                train_epochs=num_pti_epochs,
-            )
-
-            continue_training = True
-            training_iter = 1
-
-            while continue_training:
-                print("")
-                print(f"============ PTI cycle {training_iter} ============")
-                print("")
-
-                pti_project = f"pti_{training_iter}"
-                pti_output_dir = output_dir / pti_project
-                pti_checkpoint_output_dir = pti_output_dir / "model"
-                pti_sample_output_dir = pti_output_dir / "samples"
-
-                trainer(
-                    strategy=lora_strategy,
-                    pti_mode=True,
-                    project=pti_project,
-                    train_dataloader=pti_datamodule.train_dataloader,
-                    val_dataloader=pti_datamodule.val_dataloader,
-                    optimizer=pti_optimizer,
-                    lr_scheduler=pti_lr_scheduler,
-                    num_train_epochs=num_pti_epochs,
-                    gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                    # --
-                    group_labels=["emb"],
-                    sample_output_dir=pti_sample_output_dir,
-                    checkpoint_output_dir=pti_checkpoint_output_dir,
-                    sample_frequency=pti_sample_frequency,
-                )
-
-                response = input("Run another cycle? [y/n] ")
-                continue_training = response.lower().strip() != "n"
-                training_iter += 1
-
-    if not args.train_emb:
-        embeddings.persist()
-
     # LORA
     # --------------------------------------------------------------------------------

@@ -1062,9 +947,8 @@ def main():
         print("")

         lora_project = f"lora_{training_iter}"
-        lora_output_dir = output_dir / lora_project
-        lora_checkpoint_output_dir = lora_output_dir / "model"
-        lora_sample_output_dir = lora_output_dir / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / "model"
+        lora_sample_output_dir = output_dir / lora_project / "samples"

         trainer(
             strategy=lora_strategy,