From ba9fd1a10746d85d2502c8a79ac49db63d346b04 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 9 Apr 2023 11:29:31 +0200 Subject: Update --- training/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 7d49782..e14aeea 100644 --- a/training/functional.py +++ b/training/functional.py @@ -72,7 +72,7 @@ def make_grid(images, rows, cols): return grid -def get_models(pretrained_model_name_or_path: str): +def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') @@ -81,7 +81,7 @@ def get_models(pretrained_model_name_or_path: str): sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') - embeddings = patch_managed_embeddings(text_encoder) + embeddings = patch_managed_embeddings(text_encoder, emb_dropout) return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings -- cgit v1.2.3-70-g09d2