From 8e0e47217b7e18288eaa9462c6bbecf7387f3d89 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 28 Apr 2023 23:51:40 +0200
Subject: Support torch.compile

---
 environment.yaml       | 14 +++++++++-----
 train_lora.py          | 10 ++++++++++
 training/functional.py | 12 ++++++++----
 3 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index a95df2a..b161244 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -4,26 +4,30 @@ channels:
   - nvidia
   - xformers/label/dev
   - defaults
+  - conda-forge
 dependencies:
+  - gcc=11.3.0
+  - gxx=11.3.0
   - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.10.8
   - pytorch=2.0.0=*cuda11.8*
   - torchvision=0.15.0
-  - xformers=0.0.18.dev504
+  - xformers=0.0.19
   - pip:
       - -e .
+      - -e git+https://github.com/huggingface/accelerate#egg=accelerate
       - -e git+https://github.com/huggingface/diffusers#egg=diffusers
       - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
-      - accelerate==0.17.1
-      - bitsandbytes==0.37.2
+      - bitsandbytes==0.38.1
+      - hidet==0.2.3
       - lion-pytorch==0.0.7
       - peft==0.2.0
       - python-slugify>=6.1.2
-      - safetensors==0.3.0
+      - safetensors==0.3.1
       - setuptools==65.6.3
       - test-tube>=0.7.5
       - timm==0.8.17.dev0
-      - transformers==4.27.1
+      - transformers==4.28.1
       - triton==2.0.0.post1
diff --git a/train_lora.py b/train_lora.py
index d5aa78d..64346bc 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -5,9 +5,12 @@ import itertools
 from pathlib import Path
 from functools import partial
 import math
+import warnings
 
 import torch
+import torch._dynamo
 import torch.utils.checkpoint
+import hidet
 
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -32,10 +35,17 @@ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
 
 logger = get_logger(__name__)
 
+warnings.filterwarnings('ignore')
+
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
+torch._dynamo.config.log_level = logging.ERROR
+
+hidet.torch.dynamo_config.use_tensor_core(True)
+hidet.torch.dynamo_config.search_space(2)
+
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
diff --git a/training/functional.py b/training/functional.py
index 6ae35a0..e7cc20f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -624,24 +624,24 @@ def train_loop(
                     accelerator.log(logs, step=global_step)
 
                     if accelerator.is_main_process:
-                        if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
+                        if avg_acc_val.max > best_acc_val and milestone_checkpoints:
                             local_progress_bar.clear()
                             global_progress_bar.clear()
 
                             accelerator.print(
                                 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                             on_checkpoint(global_step, "milestone")
-                        best_acc_val = avg_acc_val.avg
+                        best_acc_val = avg_acc_val.max
             else:
                 if accelerator.is_main_process:
-                    if avg_acc.avg > best_acc and milestone_checkpoints:
+                    if avg_acc.max > best_acc and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
                         accelerator.print(
                             f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                         on_checkpoint(global_step, "milestone")
-                    best_acc = avg_acc.avg
+                    best_acc = avg_acc.max
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -699,6 +699,10 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
+    unet = torch.compile(unet)
+    text_encoder = torch.compile(text_encoder)
+    vae = torch.compile(vae)
+
     callbacks = strategy.callbacks(
         accelerator=accelerator,
         unet=unet,
-- 
cgit v1.2.3-70-g09d2