From b31fcb741432076f7e2f3ec9423ad935a08c6671 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 16 May 2023 07:12:14 +0200 Subject: Support LoRA training for token embeddings --- training/functional.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 49c21c7..56c2995 100644 --- a/training/functional.py +++ b/training/functional.py @@ -20,7 +20,7 @@ from tqdm.auto import tqdm from data.csv import VlpnDataset from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings +from models.clip.embeddings import ManagedCLIPTextEmbeddings from models.clip.util import get_extended_embeddings from models.clip.tokenizer import MultiCLIPTokenizer from models.convnext.discriminator import ConvNeXtDiscriminator @@ -68,11 +68,7 @@ class TrainingStrategy(): prepare: TrainingStrategyPrepareCallable -def get_models( - pretrained_model_name_or_path: str, - emb_alpha: int = 8, - emb_dropout: float = 0.0 -): +def get_models(pretrained_model_name_or_path: str): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') @@ -81,9 +77,7 @@ def get_models( sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') - embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout) - - return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings + return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler def save_samples( -- cgit v1.2.3-70-g09d2