| author | Volpeon <git@volpeon.ink> | 2023-04-29 16:35:41 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-29 16:35:41 +0200 |
| commit | 74a5974ba30c170198890e59c92463bf5319fe64 (patch) | |
| tree | 63637875e8c0a8707b0c413e3b2bbccad33f4db5 | |
| parent | Optional xformers (diff) | |
torch.compile won't work yet, keep code prepared
| mode       | file                     | lines changed |
|------------|--------------------------|---------------|
| -rw-r--r-- | environment.yaml         | 3             |
| -rw-r--r-- | environment_nightly.yaml | 31            |
| -rw-r--r-- | train_lora.py            | 19            |
| -rw-r--r-- | training/functional.py   | 3             |

4 files changed, 52 insertions, 4 deletions
```diff
diff --git a/environment.yaml b/environment.yaml
index b161244..dfbafaf 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -4,7 +4,7 @@ channels:
   - nvidia
   - xformers/label/dev
   - defaults
   - conda-forge
 dependencies:
   - gcc=11.3.0
   - gxx=11.3.0
@@ -30,4 +30,3 @@ dependencies:
   - test-tube>=0.7.5
   - timm==0.8.17.dev0
   - transformers==4.28.1
-  - triton==2.0.0.post1
```
```diff
diff --git a/environment_nightly.yaml b/environment_nightly.yaml
new file mode 100644
index 0000000..4c5c798
--- /dev/null
+++ b/environment_nightly.yaml
@@ -0,0 +1,31 @@
+name: ldd
+channels:
+  - pytorch-nightly
+  - nvidia
+  - xformers/label/dev
+  - defaults
+  - conda-forge
+dependencies:
+  - cuda-nvcc=12.1.105
+  - matplotlib=3.6.2
+  - numpy=1.24.3
+  - pip=22.3.1
+  - python=3.10.8
+  - pytorch=2.1.0.dev20230429=*cuda12.1*
+  - torchvision=0.16.0.dev20230429
+  # - xformers=0.0.19
+  - pip:
+      - -e .
+      - -e git+https://github.com/huggingface/accelerate#egg=accelerate
+      - -e git+https://github.com/huggingface/diffusers#egg=diffusers
+      - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
+      - bitsandbytes==0.38.1
+      - hidet==0.2.3
+      - lion-pytorch==0.0.7
+      - peft==0.2.0
+      - python-slugify>=6.1.2
+      - safetensors==0.3.1
+      - setuptools==65.6.3
+      - test-tube>=0.7.5
+      - timm==0.8.17.dev0
+      - transformers==4.28.1
```
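The nightly environment pins a PyTorch 2.1 dev build against CUDA 12.1, which is what the `torch.compile` path in this commit targets. A minimal sanity check for such an environment might look like the following; this is a hypothetical snippet, not part of the repository, and only assumes the pinned `pytorch=2.1.0.dev20230429=*cuda12.1*` build from the YAML above:

```python
import torch

# Hypothetical sanity check for the nightly environment: confirm the pinned
# build is active and that torch.compile matches eager execution.
print(torch.__version__)   # expected to start with "2.1.0.dev"
print(torch.version.cuda)  # expected "12.1" for the *cuda12.1* build


def f(x):
    return torch.sin(x) ** 2 + torch.cos(x) ** 2


compiled_f = torch.compile(f)  # default backend, no hidet involved here
x = torch.randn(8, device="cuda" if torch.cuda.is_available() else "cpu")
print(torch.allclose(compiled_f(x), f(x), atol=1e-5))  # expect True
```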
```diff
diff --git a/train_lora.py b/train_lora.py
index d95dbb9..3c8fc97 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
+# from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
 import numpy as np
@@ -41,10 +42,11 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.ERROR
+torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.search_space(1)
+# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
@@ -724,6 +726,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    # elif args.compile_unet:
+    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+    #
+    #     proc = AttnProcessor()
+    #
+    #     def fn_recursive_set_proc(module: torch.nn.Module):
+    #         if hasattr(module, "processor"):
+    #             module.processor = proc
+    #
+    #         for child in module.children():
+    #             fn_recursive_set_proc(child)
+    #
+    #     fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
```
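For reference, here is roughly what the commented-out block above would do once enabled: it opts the mid block out of the fused PyTorch 2.0 attention path and recursively installs a plain `AttnProcessor` on every attention module, giving `torch.compile` a simpler graph to trace. This is a sketch assembled from the disabled code in the diff, assuming the `diffusers` `AttnProcessor` import that the commit also adds (commented) near the top of the file; the function name is made up for illustration:

```python
import torch
from diffusers.models.attention_processor import AttnProcessor


def use_plain_attn_processors(unet) -> None:
    """Sketch of the disabled `elif args.compile_unet:` branch in train_lora.py."""
    # Opt out of the fused PyTorch 2.0 attention on the mid block,
    # mirroring the first commented line in the diff above.
    unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False

    proc = AttnProcessor()

    def fn_recursive_set_proc(module: torch.nn.Module):
        # Replace the processor wherever one exists, then recurse into children.
        if hasattr(module, "processor"):
            module.processor = proc
        for child in module.children():
            fn_recursive_set_proc(child)

    fn_recursive_set_proc(unet)
```

If enabled, this branch would run only when `args.use_xformers` is false and `args.compile_unet` is set, mirroring the `if`/`elif` structure in the diff.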
```diff
diff --git a/training/functional.py b/training/functional.py
index 68ea40c..38dd59f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -700,8 +700,11 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
+    vae = torch.compile(vae, backend='hidet')
+
     if compile_unet:
         unet = torch.compile(unet, backend='hidet')
+        # unet = torch.compile(unet)
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
```
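Taken together, these pieces describe the compile path the commit keeps on standby: dynamo logging and hidet settings in `train_lora.py`, and the actual `torch.compile` calls in `training/functional.py`. A minimal sketch of that path in one place follows; the function name and signature are invented for illustration, and `vae`, `unet`, and `compile_unet` are assumed to be supplied by the caller as in `train()`:

```python
import logging

import hidet
import torch


def compile_models(vae, unet, compile_unet: bool):
    """Sketch of the compile path this commit prepares but does not yet enable."""
    # Settings added near the top of train_lora.py
    torch._dynamo.config.log_level = logging.WARNING
    hidet.torch.dynamo_config.use_tensor_core(True)
    hidet.torch.dynamo_config.search_space(0)  # smallest kernel search space

    # Calls added in training/functional.py's train()
    vae = torch.compile(vae, backend='hidet')
    if compile_unet:
        unet = torch.compile(unet, backend='hidet')
        # unet = torch.compile(unet)  # stock backend, kept commented in the source
    return vae, unet
```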
