From 74a5974ba30c170198890e59c92463bf5319fe64 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 29 Apr 2023 16:35:41 +0200
Subject: torch.compile won't work yet, keep code prepared

---
 environment.yaml         |  3 +--
 environment_nightly.yaml | 31 +++++++++++++++++++++++++++++++
 train_lora.py            | 19 +++++++++++++++++--
 training/functional.py   |  3 +++
 4 files changed, 52 insertions(+), 4 deletions(-)
 create mode 100644 environment_nightly.yaml

diff --git a/environment.yaml b/environment.yaml
index b161244..dfbafaf 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -4,7 +4,7 @@ channels:
   - nvidia
   - xformers/label/dev
   - defaults
-  - conda-forge
+  - conda-forge
 dependencies:
   - gcc=11.3.0
   - gxx=11.3.0
@@ -30,4 +30,3 @@ dependencies:
   - test-tube>=0.7.5
   - timm==0.8.17.dev0
   - transformers==4.28.1
-  - triton==2.0.0.post1
diff --git a/environment_nightly.yaml b/environment_nightly.yaml
new file mode 100644
index 0000000..4c5c798
--- /dev/null
+++ b/environment_nightly.yaml
@@ -0,0 +1,31 @@
+name: ldd
+channels:
+  - pytorch-nightly
+  - nvidia
+  - xformers/label/dev
+  - defaults
+  - conda-forge
+dependencies:
+  - cuda-nvcc=12.1.105
+  - matplotlib=3.6.2
+  - numpy=1.24.3
+  - pip=22.3.1
+  - python=3.10.8
+  - pytorch=2.1.0.dev20230429=*cuda12.1*
+  - torchvision=0.16.0.dev20230429
+  # - xformers=0.0.19
+  - pip:
+      - -e .
+      - -e git+https://github.com/huggingface/accelerate#egg=accelerate
+      - -e git+https://github.com/huggingface/diffusers#egg=diffusers
+      - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
+      - bitsandbytes==0.38.1
+      - hidet==0.2.3
+      - lion-pytorch==0.0.7
+      - peft==0.2.0
+      - python-slugify>=6.1.2
+      - safetensors==0.3.1
+      - setuptools==65.6.3
+      - test-tube>=0.7.5
+      - timm==0.8.17.dev0
+      - transformers==4.28.1
diff --git a/train_lora.py b/train_lora.py
index d95dbb9..3c8fc97 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
+# from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
 import numpy as np
@@ -41,10 +42,11 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.ERROR
+torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.search_space(1)
+# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
@@ -724,6 +726,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    # elif args.compile_unet:
+    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+    #
+    #     proc = AttnProcessor()
+    #
+    #     def fn_recursive_set_proc(module: torch.nn.Module):
+    #         if hasattr(module, "processor"):
+    #             module.processor = proc
+    #
+    #         for child in module.children():
+    #             fn_recursive_set_proc(child)
+    #
+    #     fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
diff --git a/training/functional.py b/training/functional.py
index 68ea40c..38dd59f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -700,8 +700,11 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
+    vae = torch.compile(vae, backend='hidet')
+
     if compile_unet:
         unet = torch.compile(unet, backend='hidet')
+        # unet = torch.compile(unet)
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
-- 
cgit v1.2.3-70-g09d2