Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  19
1 file changed, 17 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py
index d95dbb9..3c8fc97 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
+# from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
 import numpy as np
@@ -41,10 +42,11 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.ERROR
+torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.search_space(1)
+# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
@@ -724,6 +726,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    # elif args.compile_unet:
+    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+    #
+    #     proc = AttnProcessor()
+    #
+    #     def fn_recursive_set_proc(module: torch.nn.Module):
+    #         if hasattr(module, "processor"):
+    #             module.processor = proc
+    #
+    #         for child in module.children():
+    #             fn_recursive_set_proc(child)
+    #
+    #     fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
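For context, the commented-out fallback in the last hunk forces diffusers' plain AttnProcessor onto every attention module before the UNet is compiled, presumably because the fused scaled-dot-product attention path did not trace cleanly through the dynamo backend. Below is a minimal standalone sketch of the same recursive swap, made runnable under stated assumptions: the checkpoint name is a placeholder (not from this repo), and the hidet backend registration follows hidet's documented torch.compile usage.

# Sketch of the recursive processor swap from the commented-out hunk above.
# Assumptions: diffusers and hidet are installed; the checkpoint name is a
# placeholder, not taken from this repository.
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor


def fn_recursive_set_proc(module: torch.nn.Module, proc: AttnProcessor) -> None:
    # Replace the attention processor on this module if it has one,
    # then recurse into all children so every attention block is covered.
    if hasattr(module, "processor"):
        module.processor = proc
    for child in module.children():
        fn_recursive_set_proc(child, proc)


unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder checkpoint
)
fn_recursive_set_proc(unet, AttnProcessor())

# Importing hidet registers it as a torch.compile backend (per hidet's
# documented usage); this mirrors the args.compile_unet branch above.
import hidet  # noqa: F401
unet = torch.compile(unet, backend="hidet")

Note that recent diffusers releases expose unet.set_attn_processor(AttnProcessor()), which performs the same recursion through a supported API; the hand-rolled helper here just mirrors what the diff sketches.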
