author     Volpeon <git@volpeon.ink>  2023-05-16 07:12:14 +0200
committer  Volpeon <git@volpeon.ink>  2023-05-16 07:12:14 +0200
commit     b31fcb741432076f7e2f3ec9423ad935a08c6671
tree       2ab052d3bd617a56c4ea388c200da52cff39ba37
parent     Fix for latest PEFT
Support LoRA training for token embeddings
-rw-r--r--  models/clip/embeddings.py    3
-rw-r--r--  train_lora.py               32
-rw-r--r--  train_ti.py                 10
-rw-r--r--  training/functional.py      12
-rw-r--r--  training/strategy/lora.py    4
5 files changed, 36 insertions, 25 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2b23bd3..7c7f2ac 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -86,6 +86,9 @@ def patch_managed_embeddings(
     alpha: int = 8,
     dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
+    if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
+        return text_encoder.text_model.embeddings
+
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
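
The early return makes the patch idempotent: the new ensure_embeddings() helper in train_lora.py (below) may run several times (alias tokens, --embeddings_dir, placeholder tokens, PTI), and repeated calls now hand back the already-installed managed embeddings instead of wrapping them a second time. A minimal sketch of the behaviour, assuming a standard Stable Diffusion checkpoint layout (the checkpoint name is only an example):

    from transformers import CLIPTextModel

    from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings

    # Example checkpoint; any SD-style repo with a text_encoder subfolder works.
    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
    )

    first = patch_managed_embeddings(text_encoder, 8, 0.0)   # wraps the embeddings (alpha, dropout)
    second = patch_managed_embeddings(text_encoder, 8, 0.0)  # hits the new guard
    assert isinstance(first, ManagedCLIPTextEmbeddings)
    assert first is second  # same managed instance, no double wrapping
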
diff --git a/train_lora.py b/train_lora.py
index dea58cf..167b17a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -22,17 +22,19 @@ import transformers
 import numpy as np
 from slugify import slugify
 
-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
 TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]
 
 
 logger = get_logger(__name__)
@@ -44,9 +46,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
 torch._dynamo.config.log_level = logging.WARNING
+torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.use_attention(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
@@ -322,6 +324,11 @@ def parse_args():
322 help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", 324 help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
323 ) 325 )
324 parser.add_argument( 326 parser.add_argument(
327 "--lora_text_encoder_emb",
328 action="store_true",
329 help="Include token embeddings in training. Prevents usage of TI techniques.",
330 )
331 parser.add_argument(
325 "--train_text_encoder_cycles", 332 "--train_text_encoder_cycles",
326 default=999999, 333 default=999999,
327 help="Number of epochs the text encoder will be trained." 334 help="Number of epochs the text encoder will be trained."
@@ -717,12 +724,13 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
+
+    def ensure_embeddings():
+        if args.lora_text_encoder_emb:
+            raise ValueError("Can't use TI options when training token embeddings with LoRA")
+        return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
@@ -736,7 +744,7 @@ def main():
     text_encoder_config = LoraConfig(
         r=args.lora_text_encoder_r,
         lora_alpha=args.lora_text_encoder_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES,
         lora_dropout=args.lora_text_encoder_dropout,
         bias=args.lora_text_encoder_bias,
     )
@@ -765,6 +773,8 @@ def main():
         unet.enable_gradient_checkpointing()
 
     if len(args.alias_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
 
@@ -781,6 +791,8 @@ def main():
     placeholder_token_ids = []
 
     if args.embeddings_dir is not None:
+        embeddings = ensure_embeddings()
+
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
@@ -798,6 +810,8 @@ def main():
         embeddings.persist()
 
     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
+        embeddings = ensure_embeddings()
+
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -997,6 +1011,8 @@ def main():
     # --------------------------------------------------------------------------------
 
     if args.run_pti and len(placeholder_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
         pti_datamodule = create_datamodule(
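
For context on what adding "token_embedding" to the LoRA target list does: PEFT swaps matching modules for LoRA-wrapped ones, and for the CLIP text model's nn.Embedding that means learning a low-rank delta on the token embedding matrix, which is why the script refuses to combine it with the TI-style managed embeddings. A self-contained sketch using plain transformers + PEFT rather than this repo's training loop (rank/alpha values and the checkpoint name are arbitrary examples):

    from peft import LoraConfig, get_peft_model
    from transformers import CLIPTextModel

    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
    )

    config = LoraConfig(
        r=8,
        lora_alpha=32,
        # TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING from the diff above
        target_modules=["q_proj", "v_proj", "token_embedding"],
        lora_dropout=0.0,
        bias="none",
    )
    text_encoder = get_peft_model(text_encoder, config)
    # Trainable parameters now include LoRA factors for the attention
    # projections and for text_model.embeddings.token_embedding.
    text_encoder.print_trainable_parameters()
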
diff --git a/train_ti.py b/train_ti.py
index 6fd974e..f60e3e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,13 +21,14 @@ import transformers
 import numpy as np
 from slugify import slugify
 
-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir
 
 logger = get_logger(__name__)
 
@@ -702,11 +703,8 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
+    embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
diff --git a/training/functional.py b/training/functional.py
index 49c21c7..56c2995 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
 
 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from models.convnext.discriminator import ConvNeXtDiscriminator
@@ -68,11 +68,7 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(
-    pretrained_model_name_or_path: str,
-    emb_alpha: int = 8,
-    emb_dropout: float = 0.0
-):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,9 +77,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
 
 def save_samples(
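
With embedding patching removed from get_models(), callers that need textual-inversion support patch the text encoder themselves (train_ti.py above does this unconditionally; train_lora.py only via ensure_embeddings()). A sketch of the updated call site, with the model path as a placeholder:

    from models.clip.embeddings import patch_managed_embeddings
    from training.functional import get_models

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
        "runwayml/stable-diffusion-v1-5"  # placeholder path / hub id
    )

    # Only needed when TI features (placeholder tokens, --embeddings_dir, PTI)
    # are in play; pure-LoRA runs skip this entirely.
    embeddings = patch_managed_embeddings(text_encoder, 8, 0.0)  # emb_alpha, emb_dropout
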
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
             max_grad_norm
         )
 
-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
             params = [
                 p
                 for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
             lr = lrs["emb"]
             lambda_ = emb_decay * lr
 