diff options
author | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
commit | b31fcb741432076f7e2f3ec9423ad935a08c6671 (patch) | |
tree | 2ab052d3bd617a56c4ea388c200da52cff39ba37 | |
parent | Fix for latest PEFT (diff) | |
download | textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.gz textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.bz2 textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.zip |
Support LoRA training for token embeddings
-rw-r--r-- | models/clip/embeddings.py | 3 | ||||
-rw-r--r-- | train_lora.py | 32 | ||||
-rw-r--r-- | train_ti.py | 10 | ||||
-rw-r--r-- | training/functional.py | 12 | ||||
-rw-r--r-- | training/strategy/lora.py | 4 |
5 files changed, 36 insertions, 25 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 2b23bd3..7c7f2ac 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -86,6 +86,9 @@ def patch_managed_embeddings( | |||
86 | alpha: int = 8, | 86 | alpha: int = 8, |
87 | dropout: float = 0.0 | 87 | dropout: float = 0.0 |
88 | ) -> ManagedCLIPTextEmbeddings: | 88 | ) -> ManagedCLIPTextEmbeddings: |
89 | if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): | ||
90 | return text_encoder.text_model.embeddings | ||
91 | |||
89 | text_embeddings = ManagedCLIPTextEmbeddings( | 92 | text_embeddings = ManagedCLIPTextEmbeddings( |
90 | text_encoder.config, | 93 | text_encoder.config, |
91 | text_encoder.text_model.embeddings, | 94 | text_encoder.text_model.embeddings, |
diff --git a/train_lora.py b/train_lora.py index dea58cf..167b17a 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -22,17 +22,19 @@ import transformers | |||
22 | import numpy as np | 22 | import numpy as np |
23 | from slugify import slugify | 23 | from slugify import slugify |
24 | 24 | ||
25 | from util.files import load_config, load_embeddings_from_dir | ||
26 | from data.csv import VlpnDataModule, keyword_filter | 25 | from data.csv import VlpnDataModule, keyword_filter |
26 | from models.clip.embeddings import patch_managed_embeddings | ||
27 | from training.functional import train, add_placeholder_tokens, get_models | 27 | from training.functional import train, add_placeholder_tokens, get_models |
28 | from training.strategy.lora import lora_strategy | 28 | from training.strategy.lora import lora_strategy |
29 | from training.optimization import get_scheduler | 29 | from training.optimization import get_scheduler |
30 | from training.sampler import create_named_schedule_sampler | 30 | from training.sampler import create_named_schedule_sampler |
31 | from training.util import AverageMeter, save_args | 31 | from training.util import AverageMeter, save_args |
32 | from util.files import load_config, load_embeddings_from_dir | ||
32 | 33 | ||
33 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 34 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
34 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | 35 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] |
35 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | 36 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] |
37 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] | ||
36 | 38 | ||
37 | 39 | ||
38 | logger = get_logger(__name__) | 40 | logger = get_logger(__name__) |
@@ -44,9 +46,9 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
44 | torch.backends.cudnn.benchmark = True | 46 | torch.backends.cudnn.benchmark = True |
45 | 47 | ||
46 | torch._dynamo.config.log_level = logging.WARNING | 48 | torch._dynamo.config.log_level = logging.WARNING |
49 | torch._dynamo.config.suppress_errors = True | ||
47 | 50 | ||
48 | hidet.torch.dynamo_config.use_tensor_core(True) | 51 | hidet.torch.dynamo_config.use_tensor_core(True) |
49 | hidet.torch.dynamo_config.use_attention(True) | ||
50 | hidet.torch.dynamo_config.search_space(0) | 52 | hidet.torch.dynamo_config.search_space(0) |
51 | 53 | ||
52 | 54 | ||
@@ -322,6 +324,11 @@ def parse_args(): | |||
322 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", | 324 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", |
323 | ) | 325 | ) |
324 | parser.add_argument( | 326 | parser.add_argument( |
327 | "--lora_text_encoder_emb", | ||
328 | action="store_true", | ||
329 | help="Include token embeddings in training. Prevents usage of TI techniques.", | ||
330 | ) | ||
331 | parser.add_argument( | ||
325 | "--train_text_encoder_cycles", | 332 | "--train_text_encoder_cycles", |
326 | default=999999, | 333 | default=999999, |
327 | help="Number of epochs the text encoder will be trained." | 334 | help="Number of epochs the text encoder will be trained." |
@@ -717,12 +724,13 @@ def main(): | |||
717 | 724 | ||
718 | save_args(output_dir, args) | 725 | save_args(output_dir, args) |
719 | 726 | ||
720 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 727 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) |
721 | args.pretrained_model_name_or_path, | ||
722 | args.emb_alpha, | ||
723 | args.emb_dropout | ||
724 | ) | ||
725 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 728 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) |
729 | |||
730 | def ensure_embeddings(): | ||
731 | if args.lora_text_encoder_emb: | ||
732 | raise ValueError("Can't use TI options when training token embeddings with LoRA") | ||
733 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) | ||
726 | 734 | ||
727 | unet_config = LoraConfig( | 735 | unet_config = LoraConfig( |
728 | r=args.lora_r, | 736 | r=args.lora_r, |
@@ -736,7 +744,7 @@ def main(): | |||
736 | text_encoder_config = LoraConfig( | 744 | text_encoder_config = LoraConfig( |
737 | r=args.lora_text_encoder_r, | 745 | r=args.lora_text_encoder_r, |
738 | lora_alpha=args.lora_text_encoder_alpha, | 746 | lora_alpha=args.lora_text_encoder_alpha, |
739 | target_modules=TEXT_ENCODER_TARGET_MODULES, | 747 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES, |
740 | lora_dropout=args.lora_text_encoder_dropout, | 748 | lora_dropout=args.lora_text_encoder_dropout, |
741 | bias=args.lora_text_encoder_bias, | 749 | bias=args.lora_text_encoder_bias, |
742 | ) | 750 | ) |
@@ -765,6 +773,8 @@ def main(): | |||
765 | unet.enable_gradient_checkpointing() | 773 | unet.enable_gradient_checkpointing() |
766 | 774 | ||
767 | if len(args.alias_tokens) != 0: | 775 | if len(args.alias_tokens) != 0: |
776 | embeddings = ensure_embeddings() | ||
777 | |||
768 | alias_placeholder_tokens = args.alias_tokens[::2] | 778 | alias_placeholder_tokens = args.alias_tokens[::2] |
769 | alias_initializer_tokens = args.alias_tokens[1::2] | 779 | alias_initializer_tokens = args.alias_tokens[1::2] |
770 | 780 | ||
@@ -781,6 +791,8 @@ def main(): | |||
781 | placeholder_token_ids = [] | 791 | placeholder_token_ids = [] |
782 | 792 | ||
783 | if args.embeddings_dir is not None: | 793 | if args.embeddings_dir is not None: |
794 | embeddings = ensure_embeddings() | ||
795 | |||
784 | embeddings_dir = Path(args.embeddings_dir) | 796 | embeddings_dir = Path(args.embeddings_dir) |
785 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 797 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
786 | raise ValueError("--embeddings_dir must point to an existing directory") | 798 | raise ValueError("--embeddings_dir must point to an existing directory") |
@@ -798,6 +810,8 @@ def main(): | |||
798 | embeddings.persist() | 810 | embeddings.persist() |
799 | 811 | ||
800 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | 812 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: |
813 | embeddings = ensure_embeddings() | ||
814 | |||
801 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 815 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
802 | tokenizer=tokenizer, | 816 | tokenizer=tokenizer, |
803 | embeddings=embeddings, | 817 | embeddings=embeddings, |
@@ -997,6 +1011,8 @@ def main(): | |||
997 | # -------------------------------------------------------------------------------- | 1011 | # -------------------------------------------------------------------------------- |
998 | 1012 | ||
999 | if args.run_pti and len(placeholder_tokens) != 0: | 1013 | if args.run_pti and len(placeholder_tokens) != 0: |
1014 | embeddings = ensure_embeddings() | ||
1015 | |||
1000 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | 1016 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] |
1001 | 1017 | ||
1002 | pti_datamodule = create_datamodule( | 1018 | pti_datamodule = create_datamodule( |
diff --git a/train_ti.py b/train_ti.py index 6fd974e..f60e3e5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -21,13 +21,14 @@ import transformers | |||
21 | import numpy as np | 21 | import numpy as np |
22 | from slugify import slugify | 22 | from slugify import slugify |
23 | 23 | ||
24 | from util.files import load_config, load_embeddings_from_dir | ||
25 | from data.csv import VlpnDataModule, keyword_filter | 24 | from data.csv import VlpnDataModule, keyword_filter |
25 | from models.clip.embeddings import patch_managed_embeddings | ||
26 | from training.functional import train, add_placeholder_tokens, get_models | 26 | from training.functional import train, add_placeholder_tokens, get_models |
27 | from training.strategy.ti import textual_inversion_strategy | 27 | from training.strategy.ti import textual_inversion_strategy |
28 | from training.optimization import get_scheduler | 28 | from training.optimization import get_scheduler |
29 | from training.sampler import create_named_schedule_sampler | 29 | from training.sampler import create_named_schedule_sampler |
30 | from training.util import AverageMeter, save_args | 30 | from training.util import AverageMeter, save_args |
31 | from util.files import load_config, load_embeddings_from_dir | ||
31 | 32 | ||
32 | logger = get_logger(__name__) | 33 | logger = get_logger(__name__) |
33 | 34 | ||
@@ -702,11 +703,8 @@ def main(): | |||
702 | 703 | ||
703 | save_args(output_dir, args) | 704 | save_args(output_dir, args) |
704 | 705 | ||
705 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 706 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) |
706 | args.pretrained_model_name_or_path, | 707 | embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) |
707 | args.emb_alpha, | ||
708 | args.emb_dropout | ||
709 | ) | ||
710 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 708 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) |
711 | 709 | ||
712 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 710 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
diff --git a/training/functional.py b/training/functional.py index 49c21c7..56c2995 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
20 | 20 | ||
21 | from data.csv import VlpnDataset | 21 | from data.csv import VlpnDataset |
22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
23 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 23 | from models.clip.embeddings import ManagedCLIPTextEmbeddings |
24 | from models.clip.util import get_extended_embeddings | 24 | from models.clip.util import get_extended_embeddings |
25 | from models.clip.tokenizer import MultiCLIPTokenizer | 25 | from models.clip.tokenizer import MultiCLIPTokenizer |
26 | from models.convnext.discriminator import ConvNeXtDiscriminator | 26 | from models.convnext.discriminator import ConvNeXtDiscriminator |
@@ -68,11 +68,7 @@ class TrainingStrategy(): | |||
68 | prepare: TrainingStrategyPrepareCallable | 68 | prepare: TrainingStrategyPrepareCallable |
69 | 69 | ||
70 | 70 | ||
71 | def get_models( | 71 | def get_models(pretrained_model_name_or_path: str): |
72 | pretrained_model_name_or_path: str, | ||
73 | emb_alpha: int = 8, | ||
74 | emb_dropout: float = 0.0 | ||
75 | ): | ||
76 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 72 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
77 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 73 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
78 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 74 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
@@ -81,9 +77,7 @@ def get_models( | |||
81 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 77 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
82 | pretrained_model_name_or_path, subfolder='scheduler') | 78 | pretrained_model_name_or_path, subfolder='scheduler') |
83 | 79 | ||
84 | embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout) | 80 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler |
85 | |||
86 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | ||
87 | 81 | ||
88 | 82 | ||
89 | def save_samples( | 83 | def save_samples( |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 0c0f633..f942b76 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -92,7 +92,7 @@ def lora_strategy_callbacks( | |||
92 | max_grad_norm | 92 | max_grad_norm |
93 | ) | 93 | ) |
94 | 94 | ||
95 | if use_emb_decay: | 95 | if len(placeholder_tokens) != 0 and use_emb_decay: |
96 | params = [ | 96 | params = [ |
97 | p | 97 | p |
98 | for p in text_encoder.text_model.embeddings.parameters() | 98 | for p in text_encoder.text_model.embeddings.parameters() |
@@ -102,7 +102,7 @@ def lora_strategy_callbacks( | |||
102 | 102 | ||
103 | @torch.no_grad() | 103 | @torch.no_grad() |
104 | def on_after_optimize(w, lrs: dict[str, float]): | 104 | def on_after_optimize(w, lrs: dict[str, float]): |
105 | if use_emb_decay and w is not None and "emb" in lrs: | 105 | if w is not None and "emb" in lrs: |
106 | lr = lrs["emb"] | 106 | lr = lrs["emb"] |
107 | lambda_ = emb_decay * lr | 107 | lambda_ = emb_decay * lr |
108 | 108 | ||