| author | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
| commit | b31fcb741432076f7e2f3ec9423ad935a08c6671 | |
| tree | 2ab052d3bd617a56c4ea388c200da52cff39ba37 /training | |
| parent | Fix for latest PEFT | |
Support LoRA training for token embeddings
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 12 |
| -rw-r--r-- | training/strategy/lora.py | 4 |
2 files changed, 5 insertions, 11 deletions
diff --git a/training/functional.py b/training/functional.py
index 49c21c7..56c2995 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
 
 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from models.convnext.discriminator import ConvNeXtDiscriminator
@@ -68,11 +68,7 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(
-    pretrained_model_name_or_path: str,
-    emb_alpha: int = 8,
-    emb_dropout: float = 0.0
-):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,9 +77,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
 
 def save_samples(
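With the embedding patch removed from `get_models()`, the managed token embeddings are presumably created by the caller instead, outside this diff (which is limited to `training/`). Below is a minimal sketch of that assumed caller-side wiring; the model path is a placeholder, and the `8` / `0.0` arguments simply mirror the `emb_alpha` / `emb_dropout` defaults dropped from `get_models()`:

```python
# Hypothetical caller-side wiring; not part of this commit's diff.
# get_models() now returns only the base components, so the training script
# would patch the managed embeddings itself.
from models.clip.embeddings import patch_managed_embeddings
from training.functional import get_models

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
    "path/to/pretrained-model"  # placeholder model path/id
)

# Positional args mirror the removed defaults: emb_alpha=8, emb_dropout=0.0.
# The real call site and values live outside 'training/'.
embeddings = patch_managed_embeddings(text_encoder, 8, 0.0)
```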
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
                 max_grad_norm
             )
 
-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
             params = [
                 p
                 for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
    def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
            lr = lrs["emb"]
            lambda_ = emb_decay * lr
 
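The first hunk skips the embedding-decay parameter collection entirely when no placeholder tokens are being trained (pure LoRA fine-tuning), and the second drops the explicit `use_emb_decay` check in `on_after_optimize`: a non-`None` `w` together with an `"emb"` learning rate now signals that there are embedding weights to decay. The hunk ends right after `lambda_` is computed, so the decay body itself is not shown; the sketch below is only an illustration of a step scaled by `lambda_ = emb_decay * lr` under assumed names (the shrink rule is not taken from the repository):

```python
import torch

emb_decay = 1e-2  # stand-in value; in the repo this comes from the strategy's config


# Illustrative only: one plausible decay step scaled by lambda_.
# The real body after `lambda_ = emb_decay * lr` lies outside the hunk above.
@torch.no_grad()
def on_after_optimize(w, lrs: dict[str, float]):
    # w: trained placeholder-token embedding rows (or None when there is nothing
    # to decay); lrs: per-parameter-group learning rates.
    if w is not None and "emb" in lrs:
        lr = lrs["emb"]
        lambda_ = emb_decay * lr
        w.mul_(1.0 - lambda_)  # multiplicative shrink of each embedding row
```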
