From 74a5974ba30c170198890e59c92463bf5319fe64 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 29 Apr 2023 16:35:41 +0200
Subject: torch.compile won't work yet, keep code prepared

---
 train_lora.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index d95dbb9..3c8fc97 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
+# from diffusers.models.attention_processor import AttnProcessor
 
 import transformers
 import numpy as np
@@ -41,10 +42,11 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.ERROR
+torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.search_space(1)
+# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
@@ -724,6 +726,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    # elif args.compile_unet:
+    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+    #
+    #     proc = AttnProcessor()
+    #
+    #     def fn_recursive_set_proc(module: torch.nn.Module):
+    #         if hasattr(module, "processor"):
+    #             module.processor = proc
+    #
+    #         for child in module.children():
+    #             fn_recursive_set_proc(child)
+    #
+    #     fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-- 
cgit v1.2.3-54-g00ecf