 environment.yaml         |  3
 environment_nightly.yaml | 31
 train_lora.py            | 19
 training/functional.py   |  3
 4 files changed, 52 insertions(+), 4 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index b161244..dfbafaf 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -4,7 +4,7 @@ channels:
   - nvidia
   - xformers/label/dev
   - defaults
   - conda-forge
 dependencies:
   - gcc=11.3.0
   - gxx=11.3.0
@@ -30,4 +30,3 @@ dependencies:
   - test-tube>=0.7.5
   - timm==0.8.17.dev0
   - transformers==4.28.1
-  - triton==2.0.0.post1
diff --git a/environment_nightly.yaml b/environment_nightly.yaml
new file mode 100644
index 0000000..4c5c798
--- /dev/null
+++ b/environment_nightly.yaml
@@ -0,0 +1,31 @@
+name: ldd
+channels:
+  - pytorch-nightly
+  - nvidia
+  - xformers/label/dev
+  - defaults
+  - conda-forge
+dependencies:
+  - cuda-nvcc=12.1.105
+  - matplotlib=3.6.2
+  - numpy=1.24.3
+  - pip=22.3.1
+  - python=3.10.8
+  - pytorch=2.1.0.dev20230429=*cuda12.1*
+  - torchvision=0.16.0.dev20230429
+  # - xformers=0.0.19
+  - pip:
+    - -e .
+    - -e git+https://github.com/huggingface/accelerate#egg=accelerate
+    - -e git+https://github.com/huggingface/diffusers#egg=diffusers
+    - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
+    - bitsandbytes==0.38.1
+    - hidet==0.2.3
+    - lion-pytorch==0.0.7
+    - peft==0.2.0
+    - python-slugify>=6.1.2
+    - safetensors==0.3.1
+    - setuptools==65.6.3
+    - test-tube>=0.7.5
+    - timm==0.8.17.dev0
+    - transformers==4.28.1
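
The nightly environment above pins a PyTorch 2.1 dev build against CUDA 12.1 so that torch.compile can be used (with hidet 0.2.3 as a backend). A minimal, illustrative sanity check that a resolved environment matches those pins might look like the following; it is not part of this commit and the file name is arbitrary.

# check_env.py -- illustrative sanity check for the nightly environment (not part of this commit)
import sys
import torch

print("python:", sys.version.split()[0])    # expected 3.10.x per the pin above
print("torch: ", torch.__version__)         # expected a 2.1.0.dev2023* build
print("cuda:  ", torch.version.cuda)        # expected 12.1
print("gpu:   ", torch.cuda.is_available())

# torch.compile is the point of moving to the nightlies; a trivial call
# confirms that dynamo imports and the default backend works.
fn = torch.compile(torch.nn.Linear(4, 4))
print(fn(torch.randn(2, 4)).shape)
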
diff --git a/train_lora.py b/train_lora.py
index d95dbb9..3c8fc97 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
+# from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
 import numpy as np
@@ -41,10 +42,11 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.ERROR
+torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.search_space(1)
+# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
@@ -724,6 +726,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    # elif args.compile_unet:
+    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+    #
+    #     proc = AttnProcessor()
+    #
+    #     def fn_recursive_set_proc(module: torch.nn.Module):
+    #         if hasattr(module, "processor"):
+    #             module.processor = proc
+    #
+    #         for child in module.children():
+    #             fn_recursive_set_proc(child)
+    #
+    #     fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
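
For context on the hidet settings changed above: use_tensor_core(True) lets hidet emit tensor-core kernels, and, as I understand hidet's dynamo_config, search_space(0) uses default schedules without per-operator tuning while the previous value of 1 enabled a limited tuning search, so the change trades some kernel speed for much faster compilation; the dynamo log level is also relaxed from ERROR to WARNING so compile-path warnings are no longer suppressed. Below is a minimal sketch of how these knobs combine with torch.compile, using a small stand-in module rather than the real UNet (assumptions: hidet 0.2.3 installed and a CUDA device available).

# Illustrative sketch only: same hidet/dynamo knobs as train_lora.py,
# applied to a stand-in module instead of the real UNet.
import logging
import torch
import hidet

torch._dynamo.config.log_level = logging.WARNING   # show compile-path warnings, not just errors
hidet.torch.dynamo_config.use_tensor_core(True)    # allow tensor-core kernels
hidet.torch.dynamo_config.search_space(0)          # no kernel tuning -> fastest compile

if torch.cuda.is_available():
    model = torch.nn.Sequential(
        torch.nn.Conv2d(4, 8, 3, padding=1),
        torch.nn.SiLU(),
    ).cuda()
    model = torch.compile(model, backend='hidet')  # hidet generates CUDA kernels, so GPU-only
    with torch.no_grad():
        print(model(torch.randn(1, 4, 64, 64, device='cuda')).shape)
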
diff --git a/training/functional.py b/training/functional.py
index 68ea40c..38dd59f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -700,8 +700,11 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
+    vae = torch.compile(vae, backend='hidet')
+
     if compile_unet:
         unet = torch.compile(unet, backend='hidet')
+        # unet = torch.compile(unet)
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
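
The training/functional.py change applies a freeze-then-compile pattern to the VAE unconditionally: the VAE is only used for inference during training, so it is frozen, switched to eval mode, and then wrapped by torch.compile with the hidet backend. A minimal sketch of that pattern on a stand-in module (illustrative only; not the project's actual VAE):

# Illustrative freeze-then-compile pattern, mirroring the train() change above.
import torch

vae = torch.nn.Sequential(torch.nn.Conv2d(3, 4, 3, padding=1), torch.nn.Tanh())  # stand-in "VAE"
vae.requires_grad_(False)   # never trained, so gradients are disabled
vae.eval()                  # and it stays in eval mode

if torch.cuda.is_available():
    vae = vae.cuda()
    vae = torch.compile(vae, backend='hidet')   # same call as in train(); requires hidet installed

with torch.no_grad():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    latents = vae(torch.randn(1, 3, 64, 64, device=device))
    print(latents.shape)
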