summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_lora.py66
-rw-r--r--train_ti.py68
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/dreambooth.py2
-rw-r--r--training/strategy/lora.py2
-rw-r--r--training/strategy/ti.py2
6 files changed, 89 insertions, 55 deletions
diff --git a/train_lora.py b/train_lora.py
index e81742a..4bbc64e 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -199,6 +199,11 @@ def parse_args():
199 help="The embeddings directory where Textual Inversion embeddings are stored.", 199 help="The embeddings directory where Textual Inversion embeddings are stored.",
200 ) 200 )
201 parser.add_argument( 201 parser.add_argument(
202 "--train_dir_embeddings",
203 action="store_true",
204 help="Train embeddings loaded from embeddings directory.",
205 )
206 parser.add_argument(
202 "--collection", 207 "--collection",
203 type=str, 208 type=str,
204 nargs='*', 209 nargs='*',
@@ -440,6 +445,12 @@ def parse_args():
440 help="How often to save a checkpoint and sample image", 445 help="How often to save a checkpoint and sample image",
441 ) 446 )
442 parser.add_argument( 447 parser.add_argument(
448 "--sample_num",
449 type=int,
450 default=None,
451 help="How often to save a checkpoint and sample image (in number of samples)",
452 )
453 parser.add_argument(
443 "--sample_image_size", 454 "--sample_image_size",
444 type=int, 455 type=int,
445 default=768, 456 default=768,
@@ -681,27 +692,36 @@ def main():
681 embeddings.persist() 692 embeddings.persist()
682 print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") 693 print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
683 694
695 placeholder_token_ids = []
696
684 if args.embeddings_dir is not None: 697 if args.embeddings_dir is not None:
685 embeddings_dir = Path(args.embeddings_dir) 698 embeddings_dir = Path(args.embeddings_dir)
686 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 699 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
687 raise ValueError("--embeddings_dir must point to an existing directory") 700 raise ValueError("--embeddings_dir must point to an existing directory")
688 701
689 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 702 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
690 embeddings.persist()
691 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 703 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
692 704
693 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 705 if args.train_dir_embeddings:
694 tokenizer=tokenizer, 706 args.placeholder_tokens = added_tokens
695 embeddings=embeddings, 707 placeholder_token_ids = added_ids
696 placeholder_tokens=args.placeholder_tokens, 708 print("Training embeddings from embeddings dir")
697 initializer_tokens=args.initializer_tokens, 709 else:
698 num_vectors=args.num_vectors, 710 embeddings.persist()
699 initializer_noise=args.initializer_noise, 711
700 ) 712 if not args.train_dir_embeddings:
701 stats = list(zip( 713 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
702 args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids 714 tokenizer=tokenizer,
703 )) 715 embeddings=embeddings,
704 print(f"Training embeddings: {stats}") 716 placeholder_tokens=args.placeholder_tokens,
717 initializer_tokens=args.initializer_tokens,
718 num_vectors=args.num_vectors,
719 initializer_noise=args.initializer_noise,
720 )
721 stats = list(zip(
722 args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
723 ))
724 print(f"Training embeddings: {stats}")
705 725
706 if args.scale_lr: 726 if args.scale_lr:
707 args.learning_rate_unet = ( 727 args.learning_rate_unet = (
@@ -897,6 +917,8 @@ def main():
897 args.num_train_steps / len(lora_datamodule.train_dataset) 917 args.num_train_steps / len(lora_datamodule.train_dataset)
898 ) * args.gradient_accumulation_steps 918 ) * args.gradient_accumulation_steps
899 lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) 919 lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
920 if args.sample_num is not None:
921 lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
900 922
901 params_to_optimize = [] 923 params_to_optimize = []
902 group_labels = [] 924 group_labels = []
@@ -930,15 +952,6 @@ def main():
930 ] 952 ]
931 group_labels += ["unet", "text"] 953 group_labels += ["unet", "text"]
932 954
933 lora_optimizer = create_optimizer(params_to_optimize)
934
935 lora_lr_scheduler = create_lr_scheduler(
936 gradient_accumulation_steps=args.gradient_accumulation_steps,
937 optimizer=lora_optimizer,
938 num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
939 train_epochs=num_train_epochs,
940 )
941
942 training_iter = 0 955 training_iter = 0
943 956
944 while True: 957 while True:
@@ -952,6 +965,15 @@ def main():
952 print(f"============ LoRA cycle {training_iter} ============") 965 print(f"============ LoRA cycle {training_iter} ============")
953 print("") 966 print("")
954 967
968 lora_optimizer = create_optimizer(params_to_optimize)
969
970 lora_lr_scheduler = create_lr_scheduler(
971 gradient_accumulation_steps=args.gradient_accumulation_steps,
972 optimizer=lora_optimizer,
973 num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
974 train_epochs=num_train_epochs,
975 )
976
955 lora_project = f"lora_{training_iter}" 977 lora_project = f"lora_{training_iter}"
956 lora_checkpoint_output_dir = output_dir / lora_project / "model" 978 lora_checkpoint_output_dir = output_dir / lora_project / "model"
957 lora_sample_output_dir = output_dir / lora_project / "samples" 979 lora_sample_output_dir = output_dir / lora_project / "samples"
diff --git a/train_ti.py b/train_ti.py
index ebac302..eb08bda 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -152,6 +152,11 @@ def parse_args():
152 help="The embeddings directory where Textual Inversion embeddings are stored.", 152 help="The embeddings directory where Textual Inversion embeddings are stored.",
153 ) 153 )
154 parser.add_argument( 154 parser.add_argument(
155 "--train_dir_embeddings",
156 action="store_true",
157 help="Train embeddings loaded from embeddings directory.",
158 )
159 parser.add_argument(
155 "--collection", 160 "--collection",
156 type=str, 161 type=str,
157 nargs='*', 162 nargs='*',
@@ -404,6 +409,12 @@ def parse_args():
404 help="If checkpoints are saved on maximum accuracy", 409 help="If checkpoints are saved on maximum accuracy",
405 ) 410 )
406 parser.add_argument( 411 parser.add_argument(
412 "--sample_num",
413 type=int,
414 default=None,
415 help="How often to save a checkpoint and sample image (in number of samples)",
416 )
417 parser.add_argument(
407 "--sample_frequency", 418 "--sample_frequency",
408 type=int, 419 type=int,
409 default=1, 420 default=1,
@@ -669,9 +680,14 @@ def main():
669 raise ValueError("--embeddings_dir must point to an existing directory") 680 raise ValueError("--embeddings_dir must point to an existing directory")
670 681
671 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 682 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
672 embeddings.persist()
673 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 683 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
674 684
685 if args.train_dir_embeddings:
686 args.placeholder_tokens = added_tokens
687 print("Training embeddings from embeddings dir")
688 else:
689 embeddings.persist()
690
675 if args.scale_lr: 691 if args.scale_lr:
676 args.learning_rate = ( 692 args.learning_rate = (
677 args.learning_rate * args.gradient_accumulation_steps * 693 args.learning_rate * args.gradient_accumulation_steps *
@@ -852,28 +868,8 @@ def main():
852 args.num_train_steps / len(datamodule.train_dataset) 868 args.num_train_steps / len(datamodule.train_dataset)
853 ) * args.gradient_accumulation_steps 869 ) * args.gradient_accumulation_steps
854 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 870 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
855 871 if args.sample_num is not None:
856 optimizer = create_optimizer( 872 sample_frequency = math.ceil(num_train_epochs / args.sample_num)
857 text_encoder.text_model.embeddings.token_override_embedding.parameters(),
858 lr=args.learning_rate,
859 )
860
861 lr_scheduler = get_scheduler(
862 args.lr_scheduler,
863 optimizer=optimizer,
864 num_training_steps_per_epoch=len(datamodule.train_dataloader),
865 gradient_accumulation_steps=args.gradient_accumulation_steps,
866 min_lr=args.lr_min_lr,
867 warmup_func=args.lr_warmup_func,
868 annealing_func=args.lr_annealing_func,
869 warmup_exp=args.lr_warmup_exp,
870 annealing_exp=args.lr_annealing_exp,
871 cycles=args.lr_cycles,
872 end_lr=1e3,
873 train_epochs=num_train_epochs,
874 warmup_epochs=args.lr_warmup_epochs,
875 mid_point=args.lr_mid_point,
876 )
877 873
878 training_iter = 0 874 training_iter = 0
879 875
@@ -888,6 +884,28 @@ def main():
888 print(f"------------ TI cycle {training_iter} ------------") 884 print(f"------------ TI cycle {training_iter} ------------")
889 print("") 885 print("")
890 886
887 optimizer = create_optimizer(
888 text_encoder.text_model.embeddings.token_override_embedding.parameters(),
889 lr=args.learning_rate,
890 )
891
892 lr_scheduler = get_scheduler(
893 args.lr_scheduler,
894 optimizer=optimizer,
895 num_training_steps_per_epoch=len(datamodule.train_dataloader),
896 gradient_accumulation_steps=args.gradient_accumulation_steps,
897 min_lr=args.lr_min_lr,
898 warmup_func=args.lr_warmup_func,
899 annealing_func=args.lr_annealing_func,
900 warmup_exp=args.lr_warmup_exp,
901 annealing_exp=args.lr_annealing_exp,
902 cycles=args.lr_cycles,
903 end_lr=1e3,
904 train_epochs=num_train_epochs,
905 warmup_epochs=args.lr_warmup_epochs,
906 mid_point=args.lr_mid_point,
907 )
908
891 project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}" 909 project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
892 sample_output_dir = output_dir / project / "samples" 910 sample_output_dir = output_dir / project / "samples"
893 checkpoint_output_dir = output_dir / project / "checkpoints" 911 checkpoint_output_dir = output_dir / project / "checkpoints"
@@ -908,10 +926,6 @@ def main():
908 placeholder_token_ids=placeholder_token_ids, 926 placeholder_token_ids=placeholder_token_ids,
909 ) 927 )
910 928
911 response = input("Run another cycle? [y/n] ")
912 continue_training = response.lower().strip() != "n"
913 training_iter += 1
914
915 if not args.sequential: 929 if not args.sequential:
916 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) 930 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
917 else: 931 else:
diff --git a/training/functional.py b/training/functional.py
index e14aeea..46d25f6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -644,11 +644,9 @@ def train(
644 min_snr_gamma: int = 5, 644 min_snr_gamma: int = 5,
645 **kwargs, 645 **kwargs,
646): 646):
647 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 647 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
648 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) 648 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
649 649
650 kwargs.update(extra)
651
652 vae.to(accelerator.device, dtype=dtype) 650 vae.to(accelerator.device, dtype=dtype)
653 vae.requires_grad_(False) 651 vae.requires_grad_(False)
654 vae.eval() 652 vae.eval()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 695174a..42624cd 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -198,7 +198,7 @@ def dreambooth_prepare(
198 198
199 text_encoder.text_model.embeddings.requires_grad_(False) 199 text_encoder.text_model.embeddings.requires_grad_(False)
200 200
201 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} 201 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
202 202
203 203
204dreambooth_strategy = TrainingStrategy( 204dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ae85401..73ec8f2 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -184,7 +184,7 @@ def lora_prepare(
184 184
185 text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) 185 text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
186 186
187 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} 187 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
188 188
189 189
190lora_strategy = TrainingStrategy( 190lora_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9cdc1bb..363c3f9 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -207,7 +207,7 @@ def textual_inversion_prepare(
207 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 207 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
208 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) 208 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
209 209
210 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} 210 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
211 211
212 212
213textual_inversion_strategy = TrainingStrategy( 213textual_inversion_strategy = TrainingStrategy(