diff options
| author | Volpeon <git@volpeon.ink> | 2023-05-16 10:08:25 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-16 10:08:25 +0200 |
| commit | ee85af17159617637293d011f6225c753fd98ce7 (patch) | |
| tree | ce3da85e21934d1ae2968c4c82d8502d94d4c845 | |
| parent | Update deps (diff) | |
| download | textual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.tar.gz textual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.tar.bz2 textual-inversion-diff-ee85af17159617637293d011f6225c753fd98ce7.zip | |
Patch xformers to cast dtypes
| -rw-r--r-- | train_lora.py | 13 | ||||
| -rw-r--r-- | training/functional.py | 11 |
2 files changed, 17 insertions(+), 7 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 70f0dc8..a58bef7 100644
--- a/train_lora.py
+++ b/train_lora.py
| @@ -17,6 +17,7 @@ from accelerate.logging import get_logger | |||
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from peft import LoraConfig, get_peft_model | 18 | from peft import LoraConfig, get_peft_model |
| 19 | # from diffusers.models.attention_processor import AttnProcessor | 19 | # from diffusers.models.attention_processor import AttnProcessor |
| 20 | from diffusers.utils.import_utils import is_xformers_available | ||
| 20 | import transformers | 21 | import transformers |
| 21 | 22 | ||
| 22 | import numpy as np | 23 | import numpy as np |
| @@ -33,7 +34,7 @@ from util.files import load_config, load_embeddings_from_dir | |||
| 33 | 34 | ||
| 34 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 35 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
| 35 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] | 36 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] |
| 36 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", key] | 37 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"] |
| 37 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] | 38 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] |
| 38 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] | 39 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] |
| 39 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] | 40 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] |
| @@ -54,6 +55,16 @@ hidet.torch.dynamo_config.use_tensor_core(True) | |||
| 54 | hidet.torch.dynamo_config.search_space(0) | 55 | hidet.torch.dynamo_config.search_space(0) |
| 55 | 56 | ||
| 56 | 57 | ||
| 58 | if is_xformers_available(): | ||
| 59 | import xformers | ||
| 60 | import xformers.ops | ||
| 61 | |||
| 62 | orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention | ||
| 63 | def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs): | ||
| 64 | return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs) | ||
| 65 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
| 66 | |||
| 67 | |||
| 57 | def parse_args(): | 68 | def parse_args(): |
| 58 | parser = argparse.ArgumentParser( | 69 | parser = argparse.ArgumentParser( |
| 59 | description="Simple example of a training script." | 70 | description="Simple example of a training script." |
diff --git a/training/functional.py b/training/functional.py
index 56c2995..fd3f9f4 100644
--- a/training/functional.py
+++ b/training/functional.py
| @@ -68,14 +68,13 @@ class TrainingStrategy(): | |||
| 68 | prepare: TrainingStrategyPrepareCallable | 68 | prepare: TrainingStrategyPrepareCallable |
| 69 | 69 | ||
| 70 | 70 | ||
| 71 | def get_models(pretrained_model_name_or_path: str): | 71 | def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): |
| 72 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 72 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
| 73 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 73 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype) |
| 74 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 74 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype) |
| 75 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | 75 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype) |
| 76 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 76 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') |
| 77 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 77 | sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') |
| 78 | pretrained_model_name_or_path, subfolder='scheduler') | ||
| 79 | 78 | ||
| 80 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler | 79 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler |
| 81 | 80 | ||
