author     Volpeon <git@volpeon.ink>   2023-04-29 16:35:41 +0200
committer  Volpeon <git@volpeon.ink>   2023-04-29 16:35:41 +0200
commit     74a5974ba30c170198890e59c92463bf5319fe64 (patch)
tree       63637875e8c0a8707b0c413e3b2bbccad33f4db5
parent     Optional xformers (diff)
torch.compile won't work yet, keep code prepared
-rw-r--r--  environment.yaml          |  3
-rw-r--r--  environment_nightly.yaml  | 31
-rw-r--r--  train_lora.py             | 19
-rw-r--r--  training/functional.py    |  3
4 files changed, 52 insertions, 4 deletions
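The commit message notes that torch.compile does not work yet, so the call sites are merely kept in place. As a minimal sketch of that pattern (the compile_model helper below is hypothetical, not from this repository; it only mirrors how the compile_unet flag gates torch.compile with the hidet backend in training/functional.py further down):

    import torch

    def compile_model(model: torch.nn.Module, enabled: bool) -> torch.nn.Module:
        # Keep the torch.compile call site ready but behind a flag, so the
        # surrounding training code stays prepared while the backend is unusable.
        if not enabled:
            return model
        # backend='hidet' mirrors the backend chosen in training/functional.py.
        return torch.compile(model, backend="hidet")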
diff --git a/environment.yaml b/environment.yaml
index b161244..dfbafaf 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -4,7 +4,7 @@ channels:
   - nvidia
   - xformers/label/dev
   - defaults
   - conda-forge
 dependencies:
   - gcc=11.3.0
   - gxx=11.3.0
@@ -30,4 +30,3 @@ dependencies:
   - test-tube>=0.7.5
   - timm==0.8.17.dev0
   - transformers==4.28.1
-  - triton==2.0.0.post1
diff --git a/environment_nightly.yaml b/environment_nightly.yaml
new file mode 100644
index 0000000..4c5c798
--- /dev/null
+++ b/environment_nightly.yaml
@@ -0,0 +1,31 @@
+name: ldd
+channels:
+  - pytorch-nightly
+  - nvidia
+  - xformers/label/dev
+  - defaults
+  - conda-forge
+dependencies:
+  - cuda-nvcc=12.1.105
+  - matplotlib=3.6.2
+  - numpy=1.24.3
+  - pip=22.3.1
+  - python=3.10.8
+  - pytorch=2.1.0.dev20230429=*cuda12.1*
+  - torchvision=0.16.0.dev20230429
+  # - xformers=0.0.19
+  - pip:
+      - -e .
+      - -e git+https://github.com/huggingface/accelerate#egg=accelerate
+      - -e git+https://github.com/huggingface/diffusers#egg=diffusers
+      - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
+      - bitsandbytes==0.38.1
+      - hidet==0.2.3
+      - lion-pytorch==0.0.7
+      - peft==0.2.0
+      - python-slugify>=6.1.2
+      - safetensors==0.3.1
+      - setuptools==65.6.3
+      - test-tube>=0.7.5
+      - timm==0.8.17.dev0
+      - transformers==4.28.1
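For reference, an environment like this is normally created with conda env create -f environment_nightly.yaml and activated with conda activate ldd (the name declared on its first line); it pins a PyTorch 2.1 nightly build together with hidet 0.2.3, which is what the hidet-backed torch.compile calls below target.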
diff --git a/train_lora.py b/train_lora.py
index d95dbb9..3c8fc97 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
+# from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
 import numpy as np
@@ -41,10 +42,11 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.ERROR
+torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.search_space(1)
+# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
@@ -724,6 +726,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    # elif args.compile_unet:
+    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+    #
+    #     proc = AttnProcessor()
+    #
+    #     def fn_recursive_set_proc(module: torch.nn.Module):
+    #         if hasattr(module, "processor"):
+    #             module.processor = proc
+    #
+    #         for child in module.children():
+    #             fn_recursive_set_proc(child)
+    #
+    #     fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
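The commented-out elif branch above would swap every attention processor in the UNet for diffusers' plain AttnProcessor, presumably to keep a compiled UNet off the PyTorch 2.0 attention path. A standalone, runnable sketch of that traversal (the function name set_plain_attn_processor is illustrative, not from the repository):

    import torch
    from diffusers.models.attention_processor import AttnProcessor

    def set_plain_attn_processor(root: torch.nn.Module) -> None:
        # Walk the module tree and assign the basic AttnProcessor wherever a
        # submodule exposes a 'processor' attribute, as fn_recursive_set_proc does.
        proc = AttnProcessor()

        def recurse(module: torch.nn.Module) -> None:
            if hasattr(module, "processor"):
                module.processor = proc
            for child in module.children():
                recurse(child)

        recurse(root)

Calling set_plain_attn_processor(unet) before compiling would be the uncommented equivalent of the disabled branch, minus the _use_2_0_attn toggle.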
diff --git a/training/functional.py b/training/functional.py
index 68ea40c..38dd59f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -700,8 +700,11 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
+    vae = torch.compile(vae, backend='hidet')
+
     if compile_unet:
         unet = torch.compile(unet, backend='hidet')
+        # unet = torch.compile(unet)
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
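Because the hidet-backed torch.compile calls are kept in place even though they do not work yet, a small self-contained check along these lines (not part of the repository; it assumes hidet is installed per environment_nightly.yaml and a CUDA device is available) can be used to see whether the nightly stack has caught up before enabling compile_unet:

    import torch
    import hidet  # noqa: F401 - imported for its dynamo backend, as train_lora.py does

    def hidet_compile_works() -> bool:
        # torch.compile is lazy, so the first forward pass is what actually
        # exercises the hidet backend on a tiny throwaway model.
        model = torch.nn.Sequential(
            torch.nn.Linear(64, 64),
            torch.nn.GELU(),
            torch.nn.Linear(64, 8),
        ).cuda()
        try:
            compiled = torch.compile(model, backend="hidet")
            with torch.no_grad():
                compiled(torch.randn(4, 64, device="cuda"))
            return True
        except Exception as exc:
            print(f"torch.compile with the hidet backend failed: {exc}")
            return False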