diff options
-rw-r--r-- | environment.yaml | 14 | ||||
-rw-r--r-- | train_lora.py | 10 | ||||
-rw-r--r-- | training/functional.py | 12 |
3 files changed, 27 insertions, 9 deletions
diff --git a/environment.yaml b/environment.yaml index a95df2a..b161244 100644 --- a/environment.yaml +++ b/environment.yaml | |||
@@ -4,26 +4,30 @@ channels: | |||
4 | - nvidia | 4 | - nvidia |
5 | - xformers/label/dev | 5 | - xformers/label/dev |
6 | - defaults | 6 | - defaults |
7 | - conda-forge | ||
7 | dependencies: | 8 | dependencies: |
9 | - gcc=11.3.0 | ||
10 | - gxx=11.3.0 | ||
8 | - matplotlib=3.6.2 | 11 | - matplotlib=3.6.2 |
9 | - numpy=1.23.4 | 12 | - numpy=1.23.4 |
10 | - pip=22.3.1 | 13 | - pip=22.3.1 |
11 | - python=3.10.8 | 14 | - python=3.10.8 |
12 | - pytorch=2.0.0=*cuda11.8* | 15 | - pytorch=2.0.0=*cuda11.8* |
13 | - torchvision=0.15.0 | 16 | - torchvision=0.15.0 |
14 | - xformers=0.0.18.dev504 | 17 | - xformers=0.0.19 |
15 | - pip: | 18 | - pip: |
16 | - -e . | 19 | - -e . |
20 | - -e git+https://github.com/huggingface/accelerate#egg=accelerate | ||
17 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 21 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
18 | - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation | 22 | - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation |
19 | - accelerate==0.17.1 | 23 | - bitsandbytes==0.38.1 |
20 | - bitsandbytes==0.37.2 | 24 | - hidet==0.2.3 |
21 | - lion-pytorch==0.0.7 | 25 | - lion-pytorch==0.0.7 |
22 | - peft==0.2.0 | 26 | - peft==0.2.0 |
23 | - python-slugify>=6.1.2 | 27 | - python-slugify>=6.1.2 |
24 | - safetensors==0.3.0 | 28 | - safetensors==0.3.1 |
25 | - setuptools==65.6.3 | 29 | - setuptools==65.6.3 |
26 | - test-tube>=0.7.5 | 30 | - test-tube>=0.7.5 |
27 | - timm==0.8.17.dev0 | 31 | - timm==0.8.17.dev0 |
28 | - transformers==4.27.1 | 32 | - transformers==4.28.1 |
29 | - triton==2.0.0.post1 | 33 | - triton==2.0.0.post1 |
diff --git a/train_lora.py b/train_lora.py index d5aa78d..64346bc 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -5,9 +5,12 @@ import itertools | |||
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from functools import partial | 6 | from functools import partial |
7 | import math | 7 | import math |
8 | import warnings | ||
8 | 9 | ||
9 | import torch | 10 | import torch |
11 | import torch._dynamo | ||
10 | import torch.utils.checkpoint | 12 | import torch.utils.checkpoint |
13 | import hidet | ||
11 | 14 | ||
12 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
13 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
@@ -32,10 +35,17 @@ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | |||
32 | 35 | ||
33 | logger = get_logger(__name__) | 36 | logger = get_logger(__name__) |
34 | 37 | ||
38 | warnings.filterwarnings('ignore') | ||
39 | |||
35 | 40 | ||
36 | torch.backends.cuda.matmul.allow_tf32 = True | 41 | torch.backends.cuda.matmul.allow_tf32 = True |
37 | torch.backends.cudnn.benchmark = True | 42 | torch.backends.cudnn.benchmark = True |
38 | 43 | ||
44 | torch._dynamo.config.log_level = logging.ERROR | ||
45 | |||
46 | hidet.torch.dynamo_config.use_tensor_core(True) | ||
47 | hidet.torch.dynamo_config.search_space(2) | ||
48 | |||
39 | 49 | ||
40 | def parse_args(): | 50 | def parse_args(): |
41 | parser = argparse.ArgumentParser( | 51 | parser = argparse.ArgumentParser( |
diff --git a/training/functional.py b/training/functional.py index 6ae35a0..e7cc20f 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -624,24 +624,24 @@ def train_loop( | |||
624 | accelerator.log(logs, step=global_step) | 624 | accelerator.log(logs, step=global_step) |
625 | 625 | ||
626 | if accelerator.is_main_process: | 626 | if accelerator.is_main_process: |
627 | if avg_acc_val.avg > best_acc_val and milestone_checkpoints: | 627 | if avg_acc_val.max > best_acc_val and milestone_checkpoints: |
628 | local_progress_bar.clear() | 628 | local_progress_bar.clear() |
629 | global_progress_bar.clear() | 629 | global_progress_bar.clear() |
630 | 630 | ||
631 | accelerator.print( | 631 | accelerator.print( |
632 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") | 632 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") |
633 | on_checkpoint(global_step, "milestone") | 633 | on_checkpoint(global_step, "milestone") |
634 | best_acc_val = avg_acc_val.avg | 634 | best_acc_val = avg_acc_val.max |
635 | else: | 635 | else: |
636 | if accelerator.is_main_process: | 636 | if accelerator.is_main_process: |
637 | if avg_acc.avg > best_acc and milestone_checkpoints: | 637 | if avg_acc.max > best_acc and milestone_checkpoints: |
638 | local_progress_bar.clear() | 638 | local_progress_bar.clear() |
639 | global_progress_bar.clear() | 639 | global_progress_bar.clear() |
640 | 640 | ||
641 | accelerator.print( | 641 | accelerator.print( |
642 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") | 642 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") |
643 | on_checkpoint(global_step, "milestone") | 643 | on_checkpoint(global_step, "milestone") |
644 | best_acc = avg_acc.avg | 644 | best_acc = avg_acc.max |
645 | 645 | ||
646 | # Create the pipeline using the trained modules and save it. | 646 | # Create the pipeline using the trained modules and save it. |
647 | if accelerator.is_main_process: | 647 | if accelerator.is_main_process: |
@@ -699,6 +699,10 @@ def train( | |||
699 | vae.requires_grad_(False) | 699 | vae.requires_grad_(False) |
700 | vae.eval() | 700 | vae.eval() |
701 | 701 | ||
702 | unet = torch.compile(unet) | ||
703 | text_encoder = torch.compile(text_encoder) | ||
704 | vae = torch.compile(vae) | ||
705 | |||
702 | callbacks = strategy.callbacks( | 706 | callbacks = strategy.callbacks( |
703 | accelerator=accelerator, | 707 | accelerator=accelerator, |
704 | unet=unet, | 708 | unet=unet, |