summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-28 23:51:40 +0200
committerVolpeon <git@volpeon.ink>2023-04-28 23:51:40 +0200
commit8e0e47217b7e18288eaa9462c6bbecf7387f3d89 (patch)
tree35f39eb57b55f0be6752e70110541e8c96351963
parentFixed loss/acc logging (diff)
downloadtextual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.gz
textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.bz2
textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.zip
Support torch.compile
-rw-r--r--environment.yaml14
-rw-r--r--train_lora.py10
-rw-r--r--training/functional.py12
3 files changed, 27 insertions, 9 deletions
diff --git a/environment.yaml b/environment.yaml
index a95df2a..b161244 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -4,26 +4,30 @@ channels:
4 - nvidia 4 - nvidia
5 - xformers/label/dev 5 - xformers/label/dev
6 - defaults 6 - defaults
7 - conda-forge
7dependencies: 8dependencies:
9 - gcc=11.3.0
10 - gxx=11.3.0
8 - matplotlib=3.6.2 11 - matplotlib=3.6.2
9 - numpy=1.23.4 12 - numpy=1.23.4
10 - pip=22.3.1 13 - pip=22.3.1
11 - python=3.10.8 14 - python=3.10.8
12 - pytorch=2.0.0=*cuda11.8* 15 - pytorch=2.0.0=*cuda11.8*
13 - torchvision=0.15.0 16 - torchvision=0.15.0
14 - xformers=0.0.18.dev504 17 - xformers=0.0.19
15 - pip: 18 - pip:
16 - -e . 19 - -e .
20 - -e git+https://github.com/huggingface/accelerate#egg=accelerate
17 - -e git+https://github.com/huggingface/diffusers#egg=diffusers 21 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
18 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation 22 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
19 - accelerate==0.17.1 23 - bitsandbytes==0.38.1
20 - bitsandbytes==0.37.2 24 - hidet==0.2.3
21 - lion-pytorch==0.0.7 25 - lion-pytorch==0.0.7
22 - peft==0.2.0 26 - peft==0.2.0
23 - python-slugify>=6.1.2 27 - python-slugify>=6.1.2
24 - safetensors==0.3.0 28 - safetensors==0.3.1
25 - setuptools==65.6.3 29 - setuptools==65.6.3
26 - test-tube>=0.7.5 30 - test-tube>=0.7.5
27 - timm==0.8.17.dev0 31 - timm==0.8.17.dev0
28 - transformers==4.27.1 32 - transformers==4.28.1
29 - triton==2.0.0.post1 33 - triton==2.0.0.post1
diff --git a/train_lora.py b/train_lora.py
index d5aa78d..64346bc 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -5,9 +5,12 @@ import itertools
5from pathlib import Path 5from pathlib import Path
6from functools import partial 6from functools import partial
7import math 7import math
8import warnings
8 9
9import torch 10import torch
11import torch._dynamo
10import torch.utils.checkpoint 12import torch.utils.checkpoint
13import hidet
11 14
12from accelerate import Accelerator 15from accelerate import Accelerator
13from accelerate.logging import get_logger 16from accelerate.logging import get_logger
@@ -32,10 +35,17 @@ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
32 35
33logger = get_logger(__name__) 36logger = get_logger(__name__)
34 37
38warnings.filterwarnings('ignore')
39
35 40
36torch.backends.cuda.matmul.allow_tf32 = True 41torch.backends.cuda.matmul.allow_tf32 = True
37torch.backends.cudnn.benchmark = True 42torch.backends.cudnn.benchmark = True
38 43
44torch._dynamo.config.log_level = logging.ERROR
45
46hidet.torch.dynamo_config.use_tensor_core(True)
47hidet.torch.dynamo_config.search_space(2)
48
39 49
40def parse_args(): 50def parse_args():
41 parser = argparse.ArgumentParser( 51 parser = argparse.ArgumentParser(
diff --git a/training/functional.py b/training/functional.py
index 6ae35a0..e7cc20f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -624,24 +624,24 @@ def train_loop(
624 accelerator.log(logs, step=global_step) 624 accelerator.log(logs, step=global_step)
625 625
626 if accelerator.is_main_process: 626 if accelerator.is_main_process:
627 if avg_acc_val.avg > best_acc_val and milestone_checkpoints: 627 if avg_acc_val.max > best_acc_val and milestone_checkpoints:
628 local_progress_bar.clear() 628 local_progress_bar.clear()
629 global_progress_bar.clear() 629 global_progress_bar.clear()
630 630
631 accelerator.print( 631 accelerator.print(
632 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") 632 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
633 on_checkpoint(global_step, "milestone") 633 on_checkpoint(global_step, "milestone")
634 best_acc_val = avg_acc_val.avg 634 best_acc_val = avg_acc_val.max
635 else: 635 else:
636 if accelerator.is_main_process: 636 if accelerator.is_main_process:
637 if avg_acc.avg > best_acc and milestone_checkpoints: 637 if avg_acc.max > best_acc and milestone_checkpoints:
638 local_progress_bar.clear() 638 local_progress_bar.clear()
639 global_progress_bar.clear() 639 global_progress_bar.clear()
640 640
641 accelerator.print( 641 accelerator.print(
642 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") 642 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
643 on_checkpoint(global_step, "milestone") 643 on_checkpoint(global_step, "milestone")
644 best_acc = avg_acc.avg 644 best_acc = avg_acc.max
645 645
646 # Create the pipeline using using the trained modules and save it. 646 # Create the pipeline using using the trained modules and save it.
647 if accelerator.is_main_process: 647 if accelerator.is_main_process:
@@ -699,6 +699,10 @@ def train(
699 vae.requires_grad_(False) 699 vae.requires_grad_(False)
700 vae.eval() 700 vae.eval()
701 701
702 unet = torch.compile(unet)
703 text_encoder = torch.compile(text_encoder)
704 vae = torch.compile(vae)
705
702 callbacks = strategy.callbacks( 706 callbacks = strategy.callbacks(
703 accelerator=accelerator, 707 accelerator=accelerator,
704 unet=unet, 708 unet=unet,