 infer.py                                            |   8
 pipelines/stable_diffusion/vlpn_stable_diffusion.py |   2
 train_lora.py                                       | 303
 train_ti.py                                         |   6
 training/strategy/lora.py                           |  37
 5 files changed, 305 insertions(+), 51 deletions(-)
diff --git a/infer.py b/infer.py
index 93848d7..8fdf63d 100644
--- a/infer.py
+++ b/infer.py
@@ -67,8 +67,8 @@ default_cmds = {
67 "batch_num": 1, 67 "batch_num": 1,
68 "steps": 30, 68 "steps": 30,
69 "guidance_scale": 7.0, 69 "guidance_scale": 7.0,
70 "sag_scale": 0.75, 70 "sag_scale": 0,
71 "lora_scale": 0.5, 71 "brightness_offset": 0,
72 "seed": None, 72 "seed": None,
73 "config": None, 73 "config": None,
74} 74}
@@ -192,7 +192,7 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
-        "--lora_scale",
+        "--brightness_offset",
         type=float,
     )
     parser.add_argument(
@@ -392,7 +392,7 @@ def generate(output_dir: Path, pipeline, args):
             generator=generator,
             image=init_image,
             strength=args.image_noise,
-            # cross_attention_kwargs={"scale": args.lora_scale},
+            brightness_offset=args.brightness_offset,
         ).images
 
         for j, image in enumerate(images):
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 127ca50..cfc3208 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -403,7 +403,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
-        sag_scale: float = 0.75,
+        sag_scale: float = 0.0,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
diff --git a/train_lora.py b/train_lora.py
index 1ca56d9..39bf455 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -17,9 +18,10 @@ import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, get_models
+from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
+from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
@@ -81,6 +83,43 @@ def parse_args():
81 help="The name of the current project.", 83 help="The name of the current project.",
82 ) 84 )
83 parser.add_argument( 85 parser.add_argument(
86 "--placeholder_tokens",
87 type=str,
88 nargs='*',
89 help="A token to use as a placeholder for the concept.",
90 )
91 parser.add_argument(
92 "--initializer_tokens",
93 type=str,
94 nargs='*',
95 help="A token to use as initializer word."
96 )
97 parser.add_argument(
98 "--initializer_noise",
99 type=float,
100 default=0,
101 help="Noise to apply to the initializer word"
102 )
103 parser.add_argument(
104 "--alias_tokens",
105 type=str,
106 nargs='*',
107 default=[],
108 help="Tokens to create an alias for."
109 )
110 parser.add_argument(
111 "--inverted_initializer_tokens",
112 type=str,
113 nargs='*',
114 help="A token to use as initializer word."
115 )
116 parser.add_argument(
117 "--num_vectors",
118 type=int,
119 nargs='*',
120 help="Number of vectors per embedding."
121 )
122 parser.add_argument(
84 "--exclude_collections", 123 "--exclude_collections",
85 type=str, 124 type=str,
86 nargs='*', 125 nargs='*',
@@ -187,6 +226,16 @@ def parse_args():
         default=2000
     )
     parser.add_argument(
+        "--num_pti_epochs",
+        type=int,
+        default=None
+    )
+    parser.add_argument(
+        "--num_pti_steps",
+        type=int,
+        default=500
+    )
+    parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
@@ -258,6 +307,12 @@ def parse_args():
258 help="Initial learning rate (after the potential warmup period) to use.", 307 help="Initial learning rate (after the potential warmup period) to use.",
259 ) 308 )
260 parser.add_argument( 309 parser.add_argument(
310 "--learning_rate_pti",
311 type=float,
312 default=1e-4,
313 help="Initial learning rate (after the potential warmup period) to use.",
314 )
315 parser.add_argument(
261 "--scale_lr", 316 "--scale_lr",
262 action="store_true", 317 action="store_true",
263 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 318 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -433,6 +488,23 @@ def parse_args():
433 help="The weight of prior preservation loss." 488 help="The weight of prior preservation loss."
434 ) 489 )
435 parser.add_argument( 490 parser.add_argument(
491 "--use_emb_decay",
492 action="store_true",
493 help="Whether to use embedding decay."
494 )
495 parser.add_argument(
496 "--emb_decay_target",
497 default=0.4,
498 type=float,
499 help="Embedding decay target."
500 )
501 parser.add_argument(
502 "--emb_decay",
503 default=1e+2,
504 type=float,
505 help="Embedding decay factor."
506 )
507 parser.add_argument(
436 "--max_grad_norm", 508 "--max_grad_norm",
437 default=1.0, 509 default=1.0,
438 type=float, 510 type=float,
@@ -464,6 +536,40 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]
+
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
+
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")
+
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
+
+    if isinstance(args.inverted_initializer_tokens, str):
+        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
+
+    if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
+        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
+        args.initializer_tokens += args.inverted_initializer_tokens
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
+
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
+
+    if args.alias_tokens is None:
+        args.alias_tokens = []
+
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -544,6 +650,19 @@ def main():
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
@@ -552,6 +671,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors,
+        initializer_noise=args.initializer_noise,
+    )
+    stats = list(zip(
+        args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+    ))
+    print(f"Training embeddings: {stats}")
+
     if args.scale_lr:
         args.learning_rate_unet = (
             args.learning_rate_unet * args.gradient_accumulation_steps *
@@ -561,10 +693,15 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
+        args.learning_rate_pti = (
+            args.learning_rate_pti * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
+        args.learning_rate_pti = 1e-6
         args.lr_scheduler = "exponential_growth"
 
@@ -663,18 +800,25 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
+        seed=args.seed,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        offset_noise_strength=args.offset_noise_strength,
+        sample_scheduler=sample_scheduler,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )
 
-    checkpoint_output_dir = output_dir / "model"
-    sample_output_dir = output_dir/"samples"
-
-    datamodule = VlpnDataModule(
+    create_datamodule = partial(
+        VlpnDataModule,
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
@@ -693,71 +837,146 @@ def main():
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
         seed=args.seed,
+        dtype=weight_dtype,
+    )
+
+    create_lr_scheduler = partial(
+        get_scheduler,
+        args.lr_scheduler,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        end_lr=1e2,
+        warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
+    )
+
+    # PTI
+    # --------------------------------------------------------------------------------
+
+    pti_output_dir = output_dir / "pti"
+    pti_checkpoint_output_dir = pti_output_dir / "model"
+    pti_sample_output_dir = pti_output_dir / "samples"
+
+    pti_datamodule = create_datamodule(
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+    )
+    pti_datamodule.setup()
+
+    num_pti_epochs = args.num_pti_epochs
+    pti_sample_frequency = args.sample_frequency
+    if num_pti_epochs is None:
+        num_pti_epochs = math.ceil(
+            args.num_pti_steps / len(pti_datamodule.train_dataset)
+        ) * args.gradient_accumulation_steps
+        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+    pti_optimizer = create_optimizer(
+        [
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "lr": args.learning_rate_pti,
+                "weight_decay": 0,
+            },
+        ]
+    )
+
+    pti_lr_scheduler = create_lr_scheduler(
+        optimizer=pti_optimizer,
+        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+        train_epochs=num_pti_epochs,
+    )
+
+    metrics = trainer(
+        strategy=textual_inversion_strategy,
+        project="ti",
+        train_dataloader=pti_datamodule.train_dataloader,
+        val_dataloader=pti_datamodule.val_dataloader,
+        optimizer=pti_optimizer,
+        lr_scheduler=pti_lr_scheduler,
+        num_train_epochs=num_pti_epochs,
+        # --
+        sample_output_dir=pti_sample_output_dir,
+        checkpoint_output_dir=pti_checkpoint_output_dir,
+        sample_frequency=pti_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+    )
+
+    plot_metrics(metrics, output_dir/"lr.png")
+
+    # LORA
+    # --------------------------------------------------------------------------------
+
+    lora_output_dir = output_dir / "pti"
+    lora_checkpoint_output_dir = lora_output_dir / "model"
+    lora_sample_output_dir = lora_output_dir / "samples"
+
+    lora_datamodule = create_datamodule(
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
-        dtype=weight_dtype
     )
-    datamodule.setup()
+    lora_datamodule.setup()
 
     num_train_epochs = args.num_train_epochs
-    sample_frequency = args.sample_frequency
+    lora_sample_frequency = args.sample_frequency
     if num_train_epochs is None:
         num_train_epochs = math.ceil(
-            args.num_train_steps / len(datamodule.train_dataset)
+            args.num_train_steps / len(lora_datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
-        sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+        lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
 
-    optimizer = create_optimizer(
+    lora_optimizer = create_optimizer(
         [
             {
                 "params": unet.parameters(),
                 "lr": args.learning_rate_unet,
             },
             {
-                "params": text_encoder.parameters(),
+                "params": itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
+                "lr": args.learning_rate_text,
+            },
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
                 "lr": args.learning_rate_text,
+                "weight_decay": 0,
             },
         ]
     )
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        end_lr=1e2,
+    lora_lr_scheduler = create_lr_scheduler(
+        optimizer=lora_optimizer,
+        num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
         train_epochs=num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-        mid_point=args.lr_mid_point,
     )
 
     metrics = trainer(
         strategy=lora_strategy,
         project="lora",
-        train_dataloader=datamodule.train_dataloader,
-        val_dataloader=datamodule.val_dataloader,
-        seed=args.seed,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
+        train_dataloader=lora_datamodule.train_dataloader,
+        val_dataloader=lora_datamodule.val_dataloader,
+        optimizer=lora_optimizer,
+        lr_scheduler=lora_lr_scheduler,
         num_train_epochs=num_train_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=sample_frequency,
-        offset_noise_strength=args.offset_noise_strength,
         # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
-        sample_output_dir=sample_output_dir,
-        checkpoint_output_dir=checkpoint_output_dir,
+        sample_output_dir=lora_sample_output_dir,
+        checkpoint_output_dir=lora_checkpoint_output_dir,
+        sample_frequency=lora_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         max_grad_norm=args.max_grad_norm,
-        sample_batch_size=args.sample_batch_size,
-        sample_num_batches=args.sample_batches,
-        sample_num_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
     )
 
     plot_metrics(metrics, output_dir/"lr.png")
diff --git a/train_ti.py b/train_ti.py
index fc0d68c..344b412 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,6 +750,7 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
@@ -764,8 +765,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -774,6 +773,7 @@ def main():
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
         ema_max_decay=args.ema_max_decay,
+        sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -863,9 +863,9 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
-        sample_frequency=sample_frequency,
         # --
         sample_output_dir=sample_output_dir,
+        sample_frequency=sample_frequency,
         placeholder_tokens=placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
     )
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 209785a..d51a2f3 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -14,6 +14,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from peft import get_peft_model_state_dict
 from safetensors.torch import save_file
 
+from slugify import slugify
+
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
@@ -30,6 +32,11 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -77,6 +84,22 @@ def lora_strategy_callbacks(
             max_grad_norm
         )
 
+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         if postfix != "end":
@@ -87,6 +110,12 @@ def lora_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder_.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
+
         lora_config = {}
         state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
         lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
@@ -126,6 +155,7 @@ def lora_strategy_callbacks(
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
@@ -141,7 +171,12 @@ def lora_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 lora_strategy = TrainingStrategy(