author     Volpeon <git@volpeon.ink>  2023-04-29 16:35:41 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-29 16:35:41 +0200
commit     74a5974ba30c170198890e59c92463bf5319fe64
tree       63637875e8c0a8707b0c413e3b2bbccad33f4db5 /train_lora.py
parent     Optional xformers
torch.compile won't work yet, keep code prepared
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 19
1 file changed, 17 insertions(+), 2 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index d95dbb9..3c8fc97 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
+# from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
 import numpy as np
@@ -41,10 +42,11 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.ERROR
+torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.search_space(1)
+# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
@@ -724,6 +726,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    # elif args.compile_unet:
+    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+    #
+    #     proc = AttnProcessor()
+    #
+    #     def fn_recursive_set_proc(module: torch.nn.Module):
+    #         if hasattr(module, "processor"):
+    #             module.processor = proc
+    #
+    #         for child in module.children():
+    #             fn_recursive_set_proc(child)
+    #
+    #     fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
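For context, the commented-out block kept in this commit looks like preparation for compiling the UNet once torch.compile works with this setup. Below is a minimal sketch of how the pieces might eventually fit together; the pretrained model path, the use of set_attn_processor instead of the recursive helper, and the choice of the hidet backend are assumptions for illustration, not something this commit actually enables.

# Hypothetical sketch only; the commit deliberately keeps the compile path
# disabled ("torch.compile won't work yet, keep code prepared").
import torch
import hidet  # importing hidet makes the "hidet" dynamo backend available
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

# Assumed model source; the training script loads its own pretrained weights.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Swap every attention module back to the basic AttnProcessor, mirroring the
# commented-out fn_recursive_set_proc helper, so the traced graph avoids the
# PyTorch 2.0 SDPA path that dynamo could not yet handle at the time.
unet.set_attn_processor(AttnProcessor())

# Compile through the hidet dynamo backend configured earlier in the script
# (use_tensor_core(True), search_space(0)).
unet = torch.compile(unet, backend="hidet")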