Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  140
1 file changed, 12 insertions(+), 128 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 8dbe45b..6e21634 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -159,12 +159,6 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
-        "--pti_tag_dropout",
-        type=float,
-        default=0,
-        help="Tag dropout probability.",
-    )
-    parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
         help="Shuffle tags.",
@@ -236,28 +230,12 @@ def parse_args():
         default=2000
     )
     parser.add_argument(
-        "--num_pti_epochs",
-        type=int,
-        default=None
-    )
-    parser.add_argument(
-        "--num_pti_steps",
-        type=int,
-        default=500
-    )
-    parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
-        "--pti_gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
-    parser.add_argument(
         "--lora_r",
         type=int,
         default=8,
@@ -323,12 +301,6 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--learning_rate_pti",
-        type=float,
-        default=1e-4,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
         "--learning_rate_emb",
         type=float,
         default=1e-5,
@@ -467,12 +439,6 @@ def parse_args():
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
-        "--pti_sample_frequency",
-        type=int,
-        default=1,
-        help="How often to save a checkpoint and sample image",
-    )
-    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=768,
@@ -509,12 +475,6 @@ def parse_args():
         help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--pti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
         "--sample_steps",
         type=int,
         default=10,
@@ -527,6 +487,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
+    parser.add_argument(
         "--use_emb_decay",
         action="store_true",
         help="Whether to use embedding decay."
@@ -674,7 +640,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
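(For orientation: the new --emb_dropout value is forwarded to get_models above. Below is a minimal sketch of what a token-level embedding dropout could look like; the wrapper class is hypothetical and not taken from this repository, it only illustrates dropping whole token vectors with probability p during training.)

import torch
import torch.nn as nn

class TokenEmbeddingDropout(nn.Module):
    # Hypothetical wrapper: zeroes entire token vectors with probability p.
    def __init__(self, embedding: nn.Embedding, p: float = 0.0):
        super().__init__()
        self.embedding = embedding
        self.p = p

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embeds = self.embedding(input_ids)
        if self.training and self.p > 0:
            # One Bernoulli draw per token position, rescaled like standard dropout.
            keep = (torch.rand(embeds.shape[:-1], device=embeds.device) >= self.p).float()
            embeds = embeds * keep.unsqueeze(-1) / (1.0 - self.p)
        return embeds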
@@ -720,6 +686,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -744,19 +711,14 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
-        args.learning_rate_pti = (
-            args.learning_rate_pti * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
-        )
         args.learning_rate_emb = (
-            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
+            args.learning_rate_emb * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
-        args.learning_rate_pti = 1e-6
         args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"
 
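(Side note: with the PTI-specific values gone, learning_rate_emb is scaled by the same effective batch size as the other learning rates. A small worked example of that scaling, using made-up values for the run configuration:)

# Linear scaling by effective batch size, mirroring the change above.
learning_rate_emb = 1e-5          # --learning_rate_emb default from this file
gradient_accumulation_steps = 4   # hypothetical --gradient_accumulation_steps
train_batch_size = 2              # hypothetical --train_batch_size
num_processes = 2                 # hypothetical accelerator.num_processes

scaled = learning_rate_emb * gradient_accumulation_steps * train_batch_size * num_processes
print(scaled)  # ~1.6e-04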
@@ -817,7 +779,6 @@ def main():
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
         args.learning_rate_text = None
-        args.learning_rate_pti = None
         args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
@@ -836,7 +797,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
@@ -853,7 +813,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -920,80 +879,6 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    # PTI
-    # --------------------------------------------------------------------------------
-
-    if len(args.placeholder_tokens) != 0:
-        pti_datamodule = create_datamodule(
-            batch_size=args.pti_batch_size,
-            dropout=args.pti_tag_dropout,
-            filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
-        )
-        pti_datamodule.setup()
-
-        num_pti_epochs = args.num_pti_epochs
-        pti_sample_frequency = args.pti_sample_frequency
-        if num_pti_epochs is None:
-            num_pti_epochs = math.ceil(
-                args.num_pti_steps / len(pti_datamodule.train_dataset)
-            ) * args.pti_gradient_accumulation_steps
-            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
-
-        if num_pti_epochs > 0:
-            pti_optimizer = create_optimizer(
-                [
-                    {
-                        "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-                        "lr": args.learning_rate_pti,
-                        "weight_decay": 0,
-                    },
-                ]
-            )
-
-            pti_lr_scheduler = create_lr_scheduler(
-                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                optimizer=pti_optimizer,
-                num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-                train_epochs=num_pti_epochs,
-            )
-
-            continue_training = True
-            training_iter = 1
-
-            while continue_training:
-                print("")
-                print(f"============ PTI cycle {training_iter} ============")
-                print("")
-
-                pti_project = f"pti_{training_iter}"
-                pti_output_dir = output_dir / pti_project
-                pti_checkpoint_output_dir = pti_output_dir / "model"
-                pti_sample_output_dir = pti_output_dir / "samples"
-
-                trainer(
-                    strategy=lora_strategy,
-                    pti_mode=True,
-                    project=pti_project,
-                    train_dataloader=pti_datamodule.train_dataloader,
-                    val_dataloader=pti_datamodule.val_dataloader,
-                    optimizer=pti_optimizer,
-                    lr_scheduler=pti_lr_scheduler,
-                    num_train_epochs=num_pti_epochs,
-                    gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                    # --
-                    group_labels=["emb"],
-                    sample_output_dir=pti_sample_output_dir,
-                    checkpoint_output_dir=pti_checkpoint_output_dir,
-                    sample_frequency=pti_sample_frequency,
-                )
-
-                response = input("Run another cycle? [y/n] ")
-                continue_training = response.lower().strip() != "n"
-                training_iter += 1
-
-    if not args.train_emb:
-        embeddings.persist()
-
     # LORA
     # --------------------------------------------------------------------------------
 
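(For reference, the step-to-epoch conversion that the removed PTI block performed is restated below as standalone arithmetic; the dataset size is made up for illustration.)

import math

num_pti_steps = 500                   # removed --num_pti_steps default
pti_gradient_accumulation_steps = 1   # removed --pti_gradient_accumulation_steps default
train_dataset_len = 120               # hypothetical dataset size

num_pti_epochs = math.ceil(num_pti_steps / train_dataset_len) * pti_gradient_accumulation_steps
print(num_pti_epochs)  # 5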
@@ -1062,9 +947,8 @@ def main():
         print("")
 
         lora_project = f"lora_{training_iter}"
-        lora_output_dir = output_dir / lora_project
-        lora_checkpoint_output_dir = lora_output_dir / "model"
-        lora_sample_output_dir = lora_output_dir / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / "model"
+        lora_sample_output_dir = output_dir / lora_project / "samples"
 
         trainer(
             strategy=lora_strategy,