From 5188d1e4c1d6aed3c510e5a534e39e40d8a3365d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 29 Apr 2023 10:47:08 +0200 Subject: Optional xformers --- train_lora.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/train_lora.py b/train_lora.py index 74afeed..d95dbb9 100644 --- a/train_lora.py +++ b/train_lora.py @@ -44,7 +44,7 @@ torch.backends.cudnn.benchmark = True torch._dynamo.config.log_level = logging.ERROR hidet.torch.dynamo_config.use_tensor_core(True) -hidet.torch.dynamo_config.search_space(2) +hidet.torch.dynamo_config.search_space(1) def parse_args(): @@ -460,6 +460,11 @@ def parse_args(): action="store_true", help="Compile UNet with Torch Dynamo.", ) + parser.add_argument( + "--use_xformers", + action="store_true", + help="Use xformers.", + ) parser.add_argument( "--lora_rank", type=int, @@ -715,8 +720,10 @@ def main(): text_encoder = LoraModel(text_encoder_config, text_encoder) vae.enable_slicing() - vae.set_use_memory_efficient_attention_xformers(True) - unet.enable_xformers_memory_efficient_attention() + + if args.use_xformers: + vae.set_use_memory_efficient_attention_xformers(True) + unet.enable_xformers_memory_efficient_attention() if args.gradient_checkpointing: unet.enable_gradient_checkpointing() -- cgit v1.2.3-70-g09d2