path: root/train_lora.py
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  32
1 file changed, 24 insertions(+), 8 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index dea58cf..167b17a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -22,17 +22,19 @@ import transformers
 import numpy as np
 from slugify import slugify
 
-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
 TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]
 
 
 logger = get_logger(__name__)
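Note: the hunk above adds a second target list for the text encoder. For clarity, the two constants evaluate to (values follow directly from the definitions above):

    >>> TEXT_ENCODER_TARGET_MODULES
    ['q_proj', 'v_proj']
    >>> TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING
    ['q_proj', 'v_proj', 'token_embedding']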
@@ -44,9 +46,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
 torch._dynamo.config.log_level = logging.WARNING
+torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.use_attention(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
@@ -322,6 +324,11 @@ def parse_args():
         help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
     )
     parser.add_argument(
+        "--lora_text_encoder_emb",
+        action="store_true",
+        help="Include token embeddings in training. Prevents usage of TI techniques.",
+    )
+    parser.add_argument(
         "--train_text_encoder_cycles",
         default=999999,
         help="Number of epochs the text encoder will be trained."
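The new --lora_text_encoder_emb switch takes no value (store_true). A hypothetical invocation, with <model-or-path> as a placeholder and no other flags implied by this diff:

    python train_lora.py \
        --pretrained_model_name_or_path <model-or-path> \
        --lora_text_encoder_emb

As the help text notes, enabling it rules out the Textual Inversion (TI) options handled later in main().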
@@ -717,12 +724,13 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
+
+    def ensure_embeddings():
+        if args.lora_text_encoder_emb:
+            raise ValueError("Can't use TI options when training token embeddings with LoRA")
+        return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
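The managed-embeddings patch is no longer applied inside get_models(); instead each TI code path asks for it lazily via ensure_embeddings(), which refuses to run when --lora_text_encoder_emb is set. A minimal, self-contained sketch of that guard pattern (Namespace and the lambda are stand-ins for the real args and patch_managed_embeddings):

    from argparse import Namespace

    def make_ensure_embeddings(args, text_encoder, patch_fn):
        # Same control flow as the closure in the hunk above, factored out
        # so the guard can be exercised without the rest of main().
        def ensure_embeddings():
            if args.lora_text_encoder_emb:
                raise ValueError("Can't use TI options when training token embeddings with LoRA")
            return patch_fn(text_encoder, args.emb_alpha, args.emb_dropout)
        return ensure_embeddings

    args = Namespace(lora_text_encoder_emb=True, emb_alpha=1.0, emb_dropout=0.0)
    ensure = make_ensure_embeddings(args, text_encoder=None, patch_fn=lambda *a: None)
    # ensure()  # raises ValueError because lora_text_encoder_emb is True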
@@ -736,7 +744,7 @@
     text_encoder_config = LoraConfig(
         r=args.lora_text_encoder_r,
         lora_alpha=args.lora_text_encoder_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES,
         lora_dropout=args.lora_text_encoder_dropout,
         bias=args.lora_text_encoder_bias,
     )
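With the flag set, the text-encoder LoRA config also targets the token embedding layer, so the adapter trains embedding deltas alongside the attention projections. A hedged sketch of the two resulting configs (peft.LoraConfig as used above; the numeric values are placeholders, not defaults from this script):

    from peft import LoraConfig

    base = dict(r=8, lora_alpha=32, lora_dropout=0.0, bias="none")  # placeholder values

    without_emb = LoraConfig(target_modules=["q_proj", "v_proj"], **base)
    with_emb = LoraConfig(
        target_modules=["q_proj", "v_proj", "token_embedding"], **base
    )
    # with_emb additionally wraps the text encoder's token embedding in a LoRA layer,
    # which is why the TI-style managed embeddings are disallowed in that mode.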
@@ -765,6 +773,8 @@ def main():
         unet.enable_gradient_checkpointing()
 
     if len(args.alias_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
 
@@ -781,6 +791,8 @@ def main():
     placeholder_token_ids = []
 
     if args.embeddings_dir is not None:
+        embeddings = ensure_embeddings()
+
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
@@ -798,6 +810,8 @@ def main():
         embeddings.persist()
 
     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
+        embeddings = ensure_embeddings()
+
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -997,6 +1011,8 @@ def main():
     # --------------------------------------------------------------------------------
 
     if args.run_pti and len(placeholder_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
         pti_datamodule = create_datamodule(