summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py
index 56c2995..fd3f9f4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -68,14 +68,13 @@ class TrainingStrategy():
68 prepare: TrainingStrategyPrepareCallable 68 prepare: TrainingStrategyPrepareCallable
69 69
70 70
71def get_models(pretrained_model_name_or_path: str): 71def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
72 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 72 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
73 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 73 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
74 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 74 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
75 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') 75 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
76 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 76 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
77 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 77 sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
78 pretrained_model_name_or_path, subfolder='scheduler')
79 78
80 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler 79 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
81 80