summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-05-16 10:08:25 +0200
committerVolpeon <git@volpeon.ink>2023-05-16 10:08:25 +0200
commitee85af17159617637293d011f6225c753fd98ce7 (patch)
treece3da85e21934d1ae2968c4c82d8502d94d4c845 /training
parentUpdate deps (diff)
downloadtextual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.tar.gz
textual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.tar.bz2
textual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.zip
Patch xformers to cast dtypes
Diffstat (limited to 'training')
-rw-r--r--training/functional.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py
index 56c2995..fd3f9f4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -68,14 +68,13 @@ class TrainingStrategy():
68 prepare: TrainingStrategyPrepareCallable 68 prepare: TrainingStrategyPrepareCallable
69 69
70 70
71def get_models(pretrained_model_name_or_path: str): 71def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
72 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 72 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
73 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 73 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
74 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 74 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
75 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') 75 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
76 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 76 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
77 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 77 sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
78 pretrained_model_name_or_path, subfolder='scheduler')
79 78
80 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler 79 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
81 80