 models/clip/embeddings.py |  3 +++
 train_lora.py             | 32 ++++++++++++++++++++++++--------
 train_ti.py               | 10 ++++------
 training/functional.py    | 12 +++---------
 training/strategy/lora.py |  4 ++--
 5 files changed, 36 insertions(+), 25 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2b23bd3..7c7f2ac 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -86,6 +86,9 @@ def patch_managed_embeddings(
     alpha: int = 8,
     dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
+    if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
+        return text_encoder.text_model.embeddings
+
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
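
Note: the isinstance guard added above makes patch_managed_embeddings idempotent, which is what lets train_lora.py (below) call it lazily from several code paths without re-wrapping the embeddings. A minimal sketch of the resulting behaviour, assuming text_encoder is the CLIPTextModel loaded elsewhere (e.g. by training.functional.get_models):

    from models.clip.embeddings import patch_managed_embeddings

    # text_encoder: a transformers.CLIPTextModel loaded beforehand.
    first = patch_managed_embeddings(text_encoder, 8, 0.0)   # wraps the original embeddings
    second = patch_managed_embeddings(text_encoder, 8, 0.0)  # returns the existing wrapper
    assert first is second  # repeated patching is now a no-op; alpha/dropout on later calls are ignored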
diff --git a/train_lora.py b/train_lora.py
index dea58cf..167b17a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -22,17 +22,19 @@ import transformers
 import numpy as np
 from slugify import slugify
 
-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
 TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]
 
 
 logger = get_logger(__name__)
@@ -44,9 +46,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
 torch._dynamo.config.log_level = logging.WARNING
+torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.use_attention(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
@@ -322,6 +324,11 @@ def parse_args():
         help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
     )
     parser.add_argument(
+        "--lora_text_encoder_emb",
+        action="store_true",
+        help="Include token embeddings in training. Prevents usage of TI techniques.",
+    )
+    parser.add_argument(
         "--train_text_encoder_cycles",
         default=999999,
         help="Number of epochs the text encoder will be trained."
@@ -717,12 +724,13 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
+
+    def ensure_embeddings():
+        if args.lora_text_encoder_emb:
+            raise ValueError("Can't use TI options when training token embeddings with LoRA")
+        return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
@@ -736,7 +744,7 @@ def main():
     text_encoder_config = LoraConfig(
         r=args.lora_text_encoder_r,
         lora_alpha=args.lora_text_encoder_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES,
         lora_dropout=args.lora_text_encoder_dropout,
         bias=args.lora_text_encoder_bias,
     )
@@ -765,6 +773,8 @@ def main():
         unet.enable_gradient_checkpointing()
 
     if len(args.alias_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
 
@@ -781,6 +791,8 @@ def main():
     placeholder_token_ids = []
 
     if args.embeddings_dir is not None:
+        embeddings = ensure_embeddings()
+
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
@@ -798,6 +810,8 @@ def main():
         embeddings.persist()
 
     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
+        embeddings = ensure_embeddings()
+
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -997,6 +1011,8 @@ def main():
     # --------------------------------------------------------------------------------
 
     if args.run_pti and len(placeholder_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
         pti_datamodule = create_datamodule(
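
Note: with --lora_text_encoder_emb the token embedding layer is handed to peft as a LoRA target instead of being wrapped by the managed-embeddings (TI) machinery, which is why ensure_embeddings() raises in that mode. A sketch of the two resulting text-encoder LoRA configurations; the r/alpha/dropout/bias values here are illustrative, not the script's defaults:

    from peft import LoraConfig

    TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
    TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]

    def make_text_encoder_config(train_token_embedding: bool) -> LoraConfig:
        # Mirrors the target_modules selection in main(): attention projections
        # alone, or attention projections plus the token embedding layer.
        return LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=(
                TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING
                if train_token_embedding
                else TEXT_ENCODER_TARGET_MODULES
            ),
            lora_dropout=0.1,
            bias="none",
        )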
diff --git a/train_ti.py b/train_ti.py
index 6fd974e..f60e3e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,13 +21,14 @@ import transformers
 import numpy as np
 from slugify import slugify
 
-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir
 
 logger = get_logger(__name__)
 
@@ -702,11 +703,8 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
+    embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
diff --git a/training/functional.py b/training/functional.py
index 49c21c7..56c2995 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
 
 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from models.convnext.discriminator import ConvNeXtDiscriminator
@@ -68,11 +68,7 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(
-    pretrained_model_name_or_path: str,
-    emb_alpha: int = 8,
-    emb_dropout: float = 0.0
-):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,9 +77,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
 
 def save_samples(
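
Note: since get_models() no longer patches the text encoder, each script decides for itself whether and when managed embeddings are created: train_ti.py patches eagerly right after loading, while train_lora.py defers to ensure_embeddings() and only patches when a TI-related option is actually used. A condensed sketch of the two call sites, with the model path and alpha/dropout values as placeholders:

    from models.clip.embeddings import patch_managed_embeddings
    from training.functional import get_models

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models("path/to/model")

    # train_ti.py style: always patch, TI training needs the managed embeddings.
    embeddings = patch_managed_embeddings(text_encoder, 8, 0.0)

    # train_lora.py style: patch lazily, and refuse when the embedding layer is
    # already trained through LoRA (--lora_text_encoder_emb).
    def ensure_embeddings(lora_text_encoder_emb: bool = False):
        if lora_text_encoder_emb:
            raise ValueError("Can't use TI options when training token embeddings with LoRA")
        return patch_managed_embeddings(text_encoder, 8, 0.0)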
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
                 max_grad_norm
             )
 
-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
             params = [
                 p
                 for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
    def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
             lr = lrs["emb"]
             lambda_ = emb_decay * lr
 
