author     Volpeon <git@volpeon.ink>  2023-05-16 10:08:25 +0200
committer  Volpeon <git@volpeon.ink>  2023-05-16 10:08:25 +0200
commit     ee85af17159617637293d011f6225c753fd98ce7 (patch)
tree       ce3da85e21934d1ae2968c4c82d8502d94d4c845
parent     Update deps (diff)
Patch xformers to cast dtypes
-rw-r--r--  train_lora.py           | 13
-rw-r--r--  training/functional.py  | 11
2 files changed, 17 insertions(+), 7 deletions(-)
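Note on the change: xformers' memory_efficient_attention kernels expect query, key and value to share a single dtype, but under mixed-precision LoRA training the projected tensors can disagree (for example an fp32 query meeting fp16 key/value). The patch below works around this by wrapping the function so query and value are cast to the key's dtype before dispatch. A minimal sketch of the mismatch being guarded against (tensor shapes are illustrative, not taken from the training code):

    import torch

    # Under autocast/LoRA, the query projection can come out fp32 while
    # key/value stay fp16; xformers would reject this combination.
    q = torch.randn(2, 64, 8, 40, dtype=torch.float32)
    k = torch.randn(2, 64, 8, 40, dtype=torch.float16)
    v = torch.randn(2, 64, 8, 40, dtype=torch.float16)

    # The workaround: coerce everything to the key's dtype up front.
    q, v = q.to(k.dtype), v.to(k.dtype)
    assert q.dtype == k.dtype == v.dtype == torch.float16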
diff --git a/train_lora.py b/train_lora.py
index 70f0dc8..a58bef7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -17,6 +17,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, get_peft_model
 # from diffusers.models.attention_processor import AttnProcessor
+from diffusers.utils.import_utils import is_xformers_available
 import transformers
 
 import numpy as np
@@ -33,7 +34,7 @@ from util.files import load_config, load_embeddings_from_dir
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"]
-UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", key]
+UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"]
 TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"]
 TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"]
 TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]
@@ -54,6 +55,16 @@ hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+
+    orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
+    def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs):
+        return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs)
+    xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
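With the wrapper installed, a mixed-dtype call resolves to the key's dtype instead of erroring. A hypothetical smoke test, assuming the monkeypatch above has already been applied (e.g. train_lora has been imported) and CUDA plus xformers are available; shapes are illustrative:

    import torch
    import xformers.ops

    q = torch.randn(1, 16, 8, 40, device="cuda", dtype=torch.float32)
    k = torch.randn(1, 16, 8, 40, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 16, 8, 40, device="cuda", dtype=torch.float16)

    # The patched memory_efficient_attention casts q and v to k.dtype (fp16)
    # before dispatching to the original kernel.
    out = xformers.ops.memory_efficient_attention(q, k, v)
    assert out.dtype == torch.float16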
diff --git a/training/functional.py b/training/functional.py
index 56c2995..fd3f9f4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -68,14 +68,13 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(pretrained_model_name_or_path: str):
+def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
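The new torch_dtype parameter is forwarded to each from_pretrained call for the text encoder, VAE and UNet, so callers can load the frozen weights in half precision directly. A hedged usage sketch (the model id is a placeholder, not taken from this repository):

    import torch
    from training.functional import get_models

    # Load text encoder, VAE and UNet in fp16; fp32 remains the default.
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )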