From c927eeb3516b8ecae201441a68956f32247a6d7c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 17 Feb 2023 15:56:54 +0100
Subject: Back to xformers

---
 environment.yaml    | 9 ++++++---
 infer.py            | 4 ++--
 train_dreambooth.py | 4 ++--
 train_lora.py       | 4 ++--
 train_ti.py         | 5 ++---
 5 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index 8010c09..325644f 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,17 +1,18 @@
 name: ldd
 channels:
-  - pytorch-nightly
+  - pytorch
   - nvidia
   - xformers/label/dev
   - defaults
 dependencies:
   - cudatoolkit=11.7
+  - libcufile=1.4.0.31
   - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.10.8
-  - pytorch=2.0.0.dev20230216=*cuda*
-  - torchvision=0.15.0.dev20230216
+  - pytorch=1.13.1=*cuda*
+  - torchvision=0.14.1
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
@@ -22,3 +23,5 @@ dependencies:
     - setuptools==65.6.3
     - test-tube>=0.7.5
     - transformers==4.26.1
+    - triton==2.0.0a2
+    - xformers==0.0.17.dev451
diff --git a/infer.py b/infer.py
index 13219f8..80bd208 100644
--- a/infer.py
+++ b/infer.py
@@ -245,8 +245,8 @@ def create_pipeline(model, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
-    # pipeline.enable_xformers_memory_efficient_attention()
-    pipeline.unet = torch.compile(pipeline.unet)
+    pipeline.enable_xformers_memory_efficient_attention()
+    # pipeline.unet = torch.compile(pipeline.unet)
     pipeline.enable_vae_slicing()
     pipeline.to("cuda")
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 85b756c..5a7911c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -464,8 +464,8 @@ def main():
         tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    # vae.set_use_memory_efficient_attention_xformers(True)
-    # unet.enable_xformers_memory_efficient_attention()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
diff --git a/train_lora.py b/train_lora.py
index 8a06ae8..330bcd6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -421,8 +421,8 @@ def main():
         args.pretrained_model_name_or_path)
 
     vae.enable_slicing()
-    # vae.set_use_memory_efficient_attention_xformers(True)
-    # unet.enable_xformers_memory_efficient_attention()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
 
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
diff --git a/train_ti.py b/train_ti.py
index 7d10317..3aa1027 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -538,9 +538,8 @@ def main():
         tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    # vae.set_use_memory_efficient_attention_xformers(True)
-    # unet.enable_xformers_memory_efficient_attention()
-
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
-- 
cgit v1.2.3-70-g09d2
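
Note: the pattern this patch reverts to, as a minimal standalone sketch. It assumes a stock diffusers StableDiffusionPipeline with xformers installed and a CUDA device available; the model id and prompt below are illustrative placeholders, not values from this repo.

    # Minimal sketch of the xformers path the patch re-enables in infer.py.
    import torch
    from diffusers import StableDiffusionPipeline

    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # hypothetical model id, not from this repo
        torch_dtype=torch.float16,
    )
    # Swap the default attention for xformers' memory-efficient kernels,
    # instead of wrapping the UNet in torch.compile (which the patch comments out).
    pipeline.enable_xformers_memory_efficient_attention()
    # Decode the VAE in slices to reduce peak memory, as infer.py also does.
    pipeline.enable_vae_slicing()
    pipeline.to("cuda")

    image = pipeline("a photo of a red fox in the snow").images[0]
    image.save("fox.png")

The trade-off behind the revert: torch.compile needs the PyTorch 2.0 nightlies dropped from environment.yaml, while the xformers kernels work on stable PyTorch 1.13.1 and give a comparable memory/speed benefit for attention.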