From ee85af17159617637293d011f6225c753fd98ce7 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 16 May 2023 10:08:25 +0200 Subject: Patch xformers to cast dtypes --- training/functional.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 56c2995..fd3f9f4 100644 --- a/training/functional.py +++ b/training/functional.py @@ -68,14 +68,13 @@ class TrainingStrategy(): prepare: TrainingStrategyPrepareCallable -def get_models(pretrained_model_name_or_path: str): +def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') - text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') - unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype) + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype) + unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype) noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') - sample_scheduler = UniPCMultistepScheduler.from_pretrained( - pretrained_model_name_or_path, subfolder='scheduler') + sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler -- cgit v1.2.3-70-g09d2