Diffstat (limited to 'train_lora.py')

 train_lora.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index d95dbb9..3c8fc97 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
+# from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
 import numpy as np
@@ -41,10 +42,11 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.ERROR
+torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.search_space(1)
+# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
@@ -724,6 +726,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    # elif args.compile_unet:
+    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+    #
+    #     proc = AttnProcessor()
+    #
+    #     def fn_recursive_set_proc(module: torch.nn.Module):
+    #         if hasattr(module, "processor"):
+    #             module.processor = proc
+    #
+    #         for child in module.children():
+    #             fn_recursive_set_proc(child)
+    #
+    #     fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
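
Note: the hidet dynamo settings changed above only take effect once a model is
actually compiled through TorchDynamo with the hidet backend. A minimal sketch
of that usage, assuming hidet and a CUDA device are available (the toy Linear
model stands in for the UNet; only the two config calls come from this diff):

    import torch
    import hidet

    # Settings from the diff: allow tensor-core kernels, and use search
    # space 0, hidet's fastest-to-compile (least autotuned) option.
    hidet.torch.dynamo_config.use_tensor_core(True)
    hidet.torch.dynamo_config.search_space(0)

    # Stand-in model (assumption); the training script compiles the UNet.
    model = torch.nn.Linear(64, 64).cuda().half()

    # 'hidet' is the TorchDynamo backend registered by the hidet package.
    compiled = torch.compile(model, backend='hidet')
    out = compiled(torch.randn(8, 64, device='cuda', dtype=torch.float16))

Dropping search_space from 1 to 0 trades kernel quality for much shorter
compile times, which matters when the graph being compiled is as large as a
diffusion UNet.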
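
The commented-out elif branch records a workaround for compiling the UNet:
force every attention module back to the plain AttnProcessor (rather than the
PyTorch 2.0 scaled-dot-product path) before compilation. Uncommented and
standalone, the recursive swap would look like the sketch below; the
checkpoint name is an assumption for illustration, and _use_2_0_attn is a
private diffusers flag that may not exist in other versions:

    import torch
    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor

    # Assumed checkpoint, used here only to get a real UNet instance.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )

    proc = AttnProcessor()  # the plain, non-SDPA attention implementation

    def fn_recursive_set_proc(module: torch.nn.Module):
        # Replace the processor on any module that exposes one, then recurse
        # into children so every attention block in the UNet is covered.
        if hasattr(module, "processor"):
            module.processor = proc
        for child in module.children():
            fn_recursive_set_proc(child)

    fn_recursive_set_proc(unet)

Recent diffusers releases expose unet.set_attn_processor(AttnProcessor()),
which performs essentially the same recursive walk.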