-rw-r--r--  infer.py                                             |   8
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  |   2
-rw-r--r--  train_lora.py                                        | 303
-rw-r--r--  train_ti.py                                          |   6
-rw-r--r--  training/strategy/lora.py                            |  37
5 files changed, 305 insertions, 51 deletions
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -67,8 +67,8 @@ default_cmds = {
     "batch_num": 1,
     "steps": 30,
     "guidance_scale": 7.0,
-    "sag_scale": 0.75,
-    "lora_scale": 0.5,
+    "sag_scale": 0,
+    "brightness_offset": 0,
     "seed": None,
     "config": None,
 }
@@ -192,7 +192,7 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
-        "--lora_scale",
+        "--brightness_offset",
         type=float,
     )
     parser.add_argument(
@@ -392,7 +392,7 @@ def generate(output_dir: Path, pipeline, args):
             generator=generator,
             image=init_image,
             strength=args.image_noise,
-            # cross_attention_kwargs={"scale": args.lora_scale},
+            brightness_offset=args.brightness_offset,
         ).images

         for j, image in enumerate(images):
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 127ca50..cfc3208 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -403,7 +403,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
-        sag_scale: float = 0.75,
+        sag_scale: float = 0.0,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
diff --git a/train_lora.py b/train_lora.py
index 1ca56d9..39bf455 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -17,9 +18,10 @@ import transformers

 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, get_models
+from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
+from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args

@@ -81,6 +83,43 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
+        "--placeholder_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--initializer_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
+        "--initializer_noise",
+        type=float,
+        default=0,
+        help="Noise to apply to the initializer word"
+    )
+    parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
+        "--inverted_initializer_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
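Note on the new arguments: they all use nargs='*', so each flag collects a space-separated list and --num_vectors is parsed as a list of ints. A standalone sketch of that parsing behaviour (the token values on the command line are hypothetical, not from this commit):

# Minimal sketch, not part of the commit: shows how the nargs='*' flags parse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--placeholder_tokens", type=str, nargs='*')
parser.add_argument("--initializer_tokens", type=str, nargs='*')
parser.add_argument("--alias_tokens", type=str, nargs='*', default=[])
parser.add_argument("--num_vectors", type=int, nargs='*')

# Hypothetical command line, for illustration only.
args = parser.parse_args([
    "--placeholder_tokens", "<jane>", "<style>",
    "--initializer_tokens", "woman", "painting",
    "--alias_tokens", "<sks>", "<jane>",
    "--num_vectors", "4", "2",
])
assert args.placeholder_tokens == ["<jane>", "<style>"]
assert args.num_vectors == [4, 2]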
@@ -187,6 +226,16 @@ def parse_args():
         default=2000
     )
     parser.add_argument(
+        "--num_pti_epochs",
+        type=int,
+        default=None
+    )
+    parser.add_argument(
+        "--num_pti_steps",
+        type=int,
+        default=500
+    )
+    parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
@@ -258,6 +307,12 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
+        "--learning_rate_pti",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -433,6 +488,23 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e+2,
+        type=float,
+        help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--max_grad_norm",
         default=1.0,
         type=float,
@@ -464,6 +536,40 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]
+
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
+
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")
+
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
+
+    if isinstance(args.inverted_initializer_tokens, str):
+        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
+
+    if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
+        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
+        args.initializer_tokens += args.inverted_initializer_tokens
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
+
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
+
+    if args.alias_tokens is None:
+        args.alias_tokens = []
+
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]

@@ -544,6 +650,19 @@ def main():
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
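Because --alias_tokens is a single flat list, the code above recovers (placeholder, initializer) pairs by slicing with a stride of two, which is why parse_args() rejects odd-length lists. A minimal sketch of that pairing, with hypothetical token values not taken from this commit:

# Sketch only: --alias_tokens is a flat, even-length list of (placeholder, initializer) pairs.
alias_tokens = ["<sks>", "<jane>", "<mdl>", "<style>"]  # hypothetical values

alias_placeholder_tokens = alias_tokens[::2]   # ["<sks>", "<mdl>"]
alias_initializer_tokens = alias_tokens[1::2]  # ["<jane>", "<style>"]

for placeholder, initializer in zip(alias_placeholder_tokens, alias_initializer_tokens):
    print(f"{placeholder} -> {initializer}")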
@@ -552,6 +671,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors,
+        initializer_noise=args.initializer_noise,
+    )
+    stats = list(zip(
+        args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+    ))
+    print(f"Training embeddings: {stats}")
+
     if args.scale_lr:
         args.learning_rate_unet = (
             args.learning_rate_unet * args.gradient_accumulation_steps *
@@ -561,10 +693,15 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
+        args.learning_rate_pti = (
+            args.learning_rate_pti * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )

     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
+        args.learning_rate_pti = 1e-6
         args.lr_scheduler = "exponential_growth"

     if args.optimizer == 'adam8bit':
@@ -663,18 +800,25 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
+        seed=args.seed,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        offset_noise_strength=args.offset_noise_strength,
+        sample_scheduler=sample_scheduler,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )

-    checkpoint_output_dir = output_dir / "model"
-    sample_output_dir = output_dir/"samples"
-
-    datamodule = VlpnDataModule(
+    create_datamodule = partial(
+        VlpnDataModule,
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
@@ -693,71 +837,146 @@ def main():
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
         seed=args.seed,
+        dtype=weight_dtype,
+    )
+
+    create_lr_scheduler = partial(
+        get_scheduler,
+        args.lr_scheduler,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        end_lr=1e2,
+        warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
+    )
+
+    # PTI
+    # --------------------------------------------------------------------------------
+
+    pti_output_dir = output_dir / "pti"
+    pti_checkpoint_output_dir = pti_output_dir / "model"
+    pti_sample_output_dir = pti_output_dir / "samples"
+
+    pti_datamodule = create_datamodule(
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+    )
+    pti_datamodule.setup()
+
+    num_pti_epochs = args.num_pti_epochs
+    pti_sample_frequency = args.sample_frequency
+    if num_pti_epochs is None:
+        num_pti_epochs = math.ceil(
+            args.num_pti_steps / len(pti_datamodule.train_dataset)
+        ) * args.gradient_accumulation_steps
+        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+    pti_optimizer = create_optimizer(
+        [
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "lr": args.learning_rate_pti,
+                "weight_decay": 0,
+            },
+        ]
+    )
+
+    pti_lr_scheduler = create_lr_scheduler(
+        optimizer=pti_optimizer,
+        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+        train_epochs=num_pti_epochs,
+    )
+
+    metrics = trainer(
+        strategy=textual_inversion_strategy,
+        project="ti",
+        train_dataloader=pti_datamodule.train_dataloader,
+        val_dataloader=pti_datamodule.val_dataloader,
+        optimizer=pti_optimizer,
+        lr_scheduler=pti_lr_scheduler,
+        num_train_epochs=num_pti_epochs,
+        # --
+        sample_output_dir=pti_sample_output_dir,
+        checkpoint_output_dir=pti_checkpoint_output_dir,
+        sample_frequency=pti_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+    )
+
+    plot_metrics(metrics, output_dir/"lr.png")
+
+    # LORA
+    # --------------------------------------------------------------------------------
+
+    lora_output_dir = output_dir / "pti"
+    lora_checkpoint_output_dir = lora_output_dir / "model"
+    lora_sample_output_dir = lora_output_dir / "samples"
+
+    lora_datamodule = create_datamodule(
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
-        dtype=weight_dtype
     )
-    datamodule.setup()
+    lora_datamodule.setup()

     num_train_epochs = args.num_train_epochs
-    sample_frequency = args.sample_frequency
+    lora_sample_frequency = args.sample_frequency
     if num_train_epochs is None:
         num_train_epochs = math.ceil(
-            args.num_train_steps / len(datamodule.train_dataset)
+            args.num_train_steps / len(lora_datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
-        sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+        lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))

-    optimizer = create_optimizer(
+    lora_optimizer = create_optimizer(
         [
             {
                 "params": unet.parameters(),
                 "lr": args.learning_rate_unet,
             },
             {
-                "params": text_encoder.parameters(),
+                "params": itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
+                "lr": args.learning_rate_text,
+            },
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
                 "lr": args.learning_rate_text,
+                "weight_decay": 0,
             },
         ]
     )

-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        end_lr=1e2,
+    lora_lr_scheduler = create_lr_scheduler(
+        optimizer=lora_optimizer,
+        num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
         train_epochs=num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-        mid_point=args.lr_mid_point,
     )

     metrics = trainer(
         strategy=lora_strategy,
         project="lora",
-        train_dataloader=datamodule.train_dataloader,
-        val_dataloader=datamodule.val_dataloader,
-        seed=args.seed,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
+        train_dataloader=lora_datamodule.train_dataloader,
+        val_dataloader=lora_datamodule.val_dataloader,
+        optimizer=lora_optimizer,
+        lr_scheduler=lora_lr_scheduler,
         num_train_epochs=num_train_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=sample_frequency,
-        offset_noise_strength=args.offset_noise_strength,
         # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
-        sample_output_dir=sample_output_dir,
-        checkpoint_output_dir=checkpoint_output_dir,
+        sample_output_dir=lora_sample_output_dir,
+        checkpoint_output_dir=lora_checkpoint_output_dir,
+        sample_frequency=lora_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         max_grad_norm=args.max_grad_norm,
-        sample_batch_size=args.sample_batch_size,
-        sample_num_batches=args.sample_batches,
-        sample_num_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
     )

     plot_metrics(metrics, output_dir/"lr.png")
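When --num_pti_epochs / --num_train_epochs are omitted, both phases derive an epoch count from a step budget and the dataset length, and rescale the step-based --sample_frequency to epochs. A worked sketch of that arithmetic with illustrative numbers (not taken from this commit):

# Sketch of the epoch/sample-frequency arithmetic used above (illustrative numbers only).
import math

num_pti_steps = 500                # --num_pti_steps
gradient_accumulation_steps = 2    # --gradient_accumulation_steps
train_dataset_len = 120            # hypothetical len(pti_datamodule.train_dataset)
sample_frequency = 100             # --sample_frequency, counted in steps
num_train_steps = 2000             # --num_train_steps

num_pti_epochs = math.ceil(num_pti_steps / train_dataset_len) * gradient_accumulation_steps
# ceil(500 / 120) = 5, times 2 accumulation steps -> 10 epochs

pti_sample_frequency = math.ceil(num_pti_epochs * (sample_frequency / num_train_steps))
# ceil(10 * 100 / 2000) = 1 -> sample once per epoch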
diff --git a/train_ti.py b/train_ti.py
index fc0d68c..344b412 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,6 +750,7 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
@@ -764,8 +765,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -774,6 +773,7 @@ def main():
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
         ema_max_decay=args.ema_max_decay,
+        sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -863,9 +863,9 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
-        sample_frequency=sample_frequency,
         # --
         sample_output_dir=sample_output_dir,
+        sample_frequency=sample_frequency,
         placeholder_tokens=placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
     )
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 209785a..d51a2f3 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -14,6 +14,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from peft import get_peft_model_state_dict
 from safetensors.torch import save_file

+from slugify import slugify
+
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples

@@ -30,6 +32,11 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -77,6 +84,22 @@ def lora_strategy_callbacks(
             max_grad_norm
         )

+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         if postfix != "end":
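The new on_after_optimize hook decays the norm of each trained embedding toward emb_decay_target after every optimizer step, scaled by emb_decay * lr. A standalone sketch of that update with assumed tensor shapes (not the strategy's actual call path, where w comes from on_before_optimize):

# Sketch of the norm decay applied after each optimizer step (assumed shapes, illustrative values).
import torch

emb_decay, emb_decay_target, lr = 1e-2, 0.4, 1e-4
w = torch.randn(3, 768)  # hypothetical stack of trained embedding vectors

lambda_ = emb_decay * lr
if lambda_ != 0:
    norm = w.norm(dim=-1, keepdim=True)
    # Vectors longer than the target shrink, shorter ones grow, proportional to lambda_.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))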
@@ -87,6 +110,12 @@ def lora_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder_.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
+
         lora_config = {}
         state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
         lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
@@ -126,6 +155,7 @@ def lora_strategy_callbacks(
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
@@ -141,7 +171,12 @@ def lora_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


 lora_strategy = TrainingStrategy(