From b31fcb741432076f7e2f3ec9423ad935a08c6671 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 16 May 2023 07:12:14 +0200
Subject: Support LoRA training for token embeddings

---
 training/functional.py    | 12 +++---------
 training/strategy/lora.py |  4 ++--
 2 files changed, 5 insertions(+), 11 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 49c21c7..56c2995 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
 
 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from models.convnext.discriminator import ConvNeXtDiscriminator
@@ -68,11 +68,7 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(
-    pretrained_model_name_or_path: str,
-    emb_alpha: int = 8,
-    emb_dropout: float = 0.0
-):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,9 +77,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
 
 def save_samples(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
             max_grad_norm
         )
 
-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
            params = [
                p
                for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
 
    @torch.no_grad()
    def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
            lr = lrs["emb"]
            lambda_ = emb_decay * lr
 
-- 
cgit v1.2.3-54-g00ecf
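
Usage sketch (not part of the patch): since patch_managed_embeddings is no longer called inside get_models, a caller is now expected to patch the text encoder's token embeddings itself before setting up LoRA training. The sketch below assumes the signature seen in the removed line, patch_managed_embeddings(text_encoder, alpha, dropout), and uses a hypothetical model path; the alpha/dropout values mirror the defaults previously hard-wired into get_models (emb_alpha=8, emb_dropout=0.0).

# Sketch only, not code from this commit.
from models.clip.embeddings import patch_managed_embeddings
from training.functional import get_models

# get_models() no longer returns the managed embeddings object.
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
    "runwayml/stable-diffusion-v1-5"  # hypothetical model path
)

# Patch the token embeddings explicitly in the training script; values match
# the former get_models defaults (emb_alpha=8, emb_dropout=0.0).
embeddings = patch_managed_embeddings(text_encoder, 8, 0.0)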