| author | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
| commit | b31fcb741432076f7e2f3ec9423ad935a08c6671 | |
| tree | 2ab052d3bd617a56c4ea388c200da52cff39ba37 /training | |
| parent | Fix for latest PEFT | |
Support LoRA training for token embeddings
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 12 |
| -rw-r--r-- | training/strategy/lora.py | 4 |
2 files changed, 5 insertions, 11 deletions
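
For context on what "LoRA training for token embeddings" usually means in a PEFT-based setup, here is a minimal, hypothetical sketch: the CLIP text encoder's `token_embedding` layer is added to the LoRA target modules so the token vectors receive low-rank updates alongside the attention projections. The model id, rank, and target-module list below are assumptions for illustration, not this repository's actual configuration.

```python
from peft import LoraConfig, get_peft_model
from transformers import CLIPTextModel

# Hypothetical illustration only; not taken from this repository's code.
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"  # assumed model id
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.0,
    # Listing the embedding layer (an nn.Embedding) in addition to the
    # attention projections makes PEFT wrap it with a low-rank adapter too.
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "token_embedding"],
)

text_encoder = get_peft_model(text_encoder, lora_config)
text_encoder.print_trainable_parameters()
```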
```diff
diff --git a/training/functional.py b/training/functional.py
index 49c21c7..56c2995 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
 
 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from models.convnext.discriminator import ConvNeXtDiscriminator
@@ -68,11 +68,7 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(
-    pretrained_model_name_or_path: str,
-    emb_alpha: int = 8,
-    emb_dropout: float = 0.0
-):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,9 +77,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
 
 def save_samples(
```
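
The net effect of this hunk is that `get_models` no longer patches the managed embeddings itself: the `emb_alpha` / `emb_dropout` parameters are gone and the function returns only the base components. A hedged sketch of the caller-side pattern this implies (the actual call sites are not part of this diff; the model path is an assumption):

```python
# Sketch of the implied caller-side usage after this change; not shown in the diff.
from models.clip.embeddings import patch_managed_embeddings  # repo-local module
from training.functional import get_models

pretrained = "runwayml/stable-diffusion-v1-5"  # assumed model path
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(pretrained)

# If managed embeddings are still wanted, the caller now patches them explicitly,
# passing what used to be get_models' emb_alpha / emb_dropout defaults.
embeddings = patch_managed_embeddings(text_encoder, 8, 0.0)
```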
```diff
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
             max_grad_norm
         )
 
-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
            params = [
                 p
                 for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
             lr = lrs["emb"]
             lambda_ = emb_decay * lr
 
```
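
Only the guard and `lambda_ = emb_decay * lr` are visible in this hunk; the decay step itself lies outside it. Below is a hedged sketch of a norm-targeting embedding decay consistent with those lines, with `emb_decay_target` and the update rule as assumptions for illustration:

```python
import torch

@torch.no_grad()
def apply_emb_decay(w: torch.Tensor, lr: float, emb_decay: float,
                    emb_decay_target: float = 0.4) -> None:
    """Hedged sketch: nudge each trained token embedding's norm toward a target.

    Only `lambda_ = emb_decay * lr` appears in the hunk above; the rest of this
    body (including `emb_decay_target`) is assumed, not taken from the repo.
    """
    lambda_ = emb_decay * lr
    if lambda_ == 0:
        return
    norm = w.norm(dim=-1, keepdim=True)
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
```

Dropping `use_emb_decay` from the `on_after_optimize` guard works because `w` is only non-None when the decay branch in `on_before_optimize` collected embedding parameters, and that branch now also requires `len(placeholder_tokens) != 0`, so the decay runs only when token embeddings are actually being trained with their own `"emb"` learning-rate group.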