diff options
| author | Volpeon <git@volpeon.ink> | 2023-05-16 10:08:25 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-16 10:08:25 +0200 | 
| commit | ee85af17159617637293d011f6225c753fd98ce7 (patch) | |
| tree | ce3da85e21934d1ae2968c4c82d8502d94d4c845 /training | |
| parent | Update deps (diff) | |
| download | textual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.tar.gz textual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.tar.bz2 textual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.zip  | |
Patch xformers to cast dtypes
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 11 | 
1 files changed, 5 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py index 56c2995..fd3f9f4 100644 --- a/training/functional.py +++ b/training/functional.py  | |||
| @@ -68,14 +68,13 @@ class TrainingStrategy(): | |||
| 68 | prepare: TrainingStrategyPrepareCallable | 68 | prepare: TrainingStrategyPrepareCallable | 
| 69 | 69 | ||
| 70 | 70 | ||
| 71 | def get_models(pretrained_model_name_or_path: str): | 71 | def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): | 
| 72 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 72 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 
| 73 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 73 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype) | 
| 74 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 74 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype) | 
| 75 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | 75 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype) | 
| 76 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 76 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 
| 77 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 77 | sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 
| 78 | pretrained_model_name_or_path, subfolder='scheduler') | ||
| 79 | 78 | ||
| 80 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler | 79 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler | 
| 81 | 80 | ||
