From ee85af17159617637293d011f6225c753fd98ce7 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 16 May 2023 10:08:25 +0200
Subject: Patch xformers to cast dtypes

---
 train_lora.py          | 13 ++++++++++++-
 training/functional.py | 11 +++++------
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 70f0dc8..a58bef7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -17,6 +17,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, get_peft_model
 # from diffusers.models.attention_processor import AttnProcessor
+from diffusers.utils.import_utils import is_xformers_available
 import transformers
 import numpy as np
 
@@ -33,7 +34,7 @@ from util.files import load_config, load_embeddings_from_dir
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"]
-UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", key]
+UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"]
 TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"]
 TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"]
 TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]
@@ -54,6 +55,16 @@ hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+
+    orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
+    def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs):
+        return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs)
+    xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
diff --git a/training/functional.py b/training/functional.py
index 56c2995..fd3f9f4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -68,14 +68,13 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(pretrained_model_name_or_path: str):
+def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
-- 
cgit v1.2.3-70-g09d2
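
Editor's note, not part of the patch: a minimal standalone sketch of the dtype-casting wrapper that the train_lora.py hunk above installs, useful for sanity-checking the behaviour outside the training script. It assumes torch and xformers are installed and a CUDA device is available; the tensor shapes and the float32/float16 mix are illustrative only.

# Standalone sketch (assumption): reproduces the wrapper from the patch so the
# mixed-dtype call can be checked in isolation. Requires CUDA and xformers.
import torch
import xformers.ops

orig_memory_efficient_attention = xformers.ops.memory_efficient_attention

def memory_efficient_attention_cast(query, key, value, **kwargs):
    # Cast query and value to the key's dtype before dispatching, mirroring the patch.
    return orig_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs)

xformers.ops.memory_efficient_attention = memory_efficient_attention_cast

# A float32 query against float16 key/value would normally fail xformers' dtype
# check; with the wrapper in place the call runs and returns a float16 tensor.
q = torch.randn(1, 16, 8, 40, device="cuda", dtype=torch.float32)
k = torch.randn(1, 16, 8, 40, device="cuda", dtype=torch.float16)
v = torch.randn(1, 16, 8, 40, device="cuda", dtype=torch.float16)
out = xformers.ops.memory_efficient_attention(q, k, v)
print(out.shape, out.dtype)  # torch.Size([1, 16, 8, 40]) torch.float16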