Diffstat (limited to 'train_lora.py')
-rw-r--r--    train_lora.py    32
1 file changed, 24 insertions, 8 deletions
diff --git a/train_lora.py b/train_lora.py
index dea58cf..167b17a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -22,17 +22,19 @@ import transformers
 import numpy as np
 from slugify import slugify
 
-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
 TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]
 
 
 logger = get_logger(__name__)
@@ -44,9 +46,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
 torch._dynamo.config.log_level = logging.WARNING
+torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.use_attention(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
@@ -322,6 +324,11 @@ def parse_args():
         help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
     )
     parser.add_argument(
+        "--lora_text_encoder_emb",
+        action="store_true",
+        help="Include token embeddings in training. Prevents usage of TI techniques.",
+    )
+    parser.add_argument(
         "--train_text_encoder_cycles",
         default=999999,
         help="Number of epochs the text encoder will be trained."
@@ -717,12 +724,13 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
+
+    def ensure_embeddings():
+        if args.lora_text_encoder_emb:
+            raise ValueError("Can't use TI options when training token embeddings with LoRA")
+        return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
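Note on the hunk above: get_models() no longer patches and returns the managed embeddings itself; ensure_embeddings() now creates them lazily, only when one of the textual-inversion paths further down actually needs them, and refuses to combine those paths with --lora_text_encoder_emb. Below is a minimal, self-contained sketch of that guard, with patch_managed_embeddings stubbed out and text_encoder reduced to a placeholder (the real ones come from models.clip.embeddings and get_models()):

from argparse import Namespace

def patch_managed_embeddings(text_encoder, alpha, dropout):
    # Stand-in stub for models.clip.embeddings.patch_managed_embeddings.
    return f"managed embeddings (alpha={alpha}, dropout={dropout})"

text_encoder = object()  # placeholder for the CLIP text encoder
args = Namespace(lora_text_encoder_emb=True, emb_alpha=1.0, emb_dropout=0.0)

def ensure_embeddings():
    # Mirrors the helper in the hunk above: TI-managed embeddings and LoRA
    # training of the token embedding are mutually exclusive.
    if args.lora_text_encoder_emb:
        raise ValueError("Can't use TI options when training token embeddings with LoRA")
    return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)

try:
    ensure_embeddings()
except ValueError as err:
    print(err)  # raised because --lora_text_encoder_emb is set

args.lora_text_encoder_emb = False
print(ensure_embeddings())  # embeddings are only patched on demand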
@@ -736,7 +744,7 @@ def main():
     text_encoder_config = LoraConfig(
         r=args.lora_text_encoder_r,
         lora_alpha=args.lora_text_encoder_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES,
         lora_dropout=args.lora_text_encoder_dropout,
         bias=args.lora_text_encoder_bias,
     )
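For reference on the line changed above: the conditional only switches which module names PEFT will wrap with LoRA adapters. The snippet below is an illustrative check rather than part of the change; it assumes the Hugging Face CLIPTextModel layout used by Stable Diffusion (attention projections named q_proj/v_proj, embedding matrix named token_embedding) and that PEFT resolves a list of target_modules by matching module-name suffixes.

from transformers import CLIPTextModel

TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]

# Any CLIP text encoder with the standard layout works; the model id is only an example.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

def matched(targets):
    # Approximation of PEFT's behaviour for a list of target_modules:
    # select modules whose dotted name ends with one of the target strings.
    return [name for name, _ in text_encoder.named_modules()
            if any(name == t or name.endswith("." + t) for t in targets)]

print(len(matched(TEXT_ENCODER_TARGET_MODULES)))                 # q_proj/v_proj in every attention block
print(len(matched(TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING)))  # plus embeddings.token_embedding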
@@ -765,6 +773,8 @@ def main():
         unet.enable_gradient_checkpointing()
 
     if len(args.alias_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
 
@@ -781,6 +791,8 @@ def main():
     placeholder_token_ids = []
 
     if args.embeddings_dir is not None:
+        embeddings = ensure_embeddings()
+
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
@@ -798,6 +810,8 @@ def main():
         embeddings.persist()
 
     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
+        embeddings = ensure_embeddings()
+
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -997,6 +1011,8 @@ def main():
     # --------------------------------------------------------------------------------
 
     if args.run_pti and len(placeholder_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
         pti_datamodule = create_datamodule(