 environment.yaml       | 14 +++++++++-----
 train_lora.py          | 10 ++++++++++
 training/functional.py | 12 ++++++++----
 3 files changed, 27 insertions(+), 9 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index a95df2a..b161244 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -4,26 +4,30 @@ channels:
   - nvidia
   - xformers/label/dev
   - defaults
+  - conda-forge
 dependencies:
+  - gcc=11.3.0
+  - gxx=11.3.0
   - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.10.8
   - pytorch=2.0.0=*cuda11.8*
   - torchvision=0.15.0
-  - xformers=0.0.18.dev504
+  - xformers=0.0.19
   - pip:
       - -e .
+      - -e git+https://github.com/huggingface/accelerate#egg=accelerate
       - -e git+https://github.com/huggingface/diffusers#egg=diffusers
       - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
-      - accelerate==0.17.1
-      - bitsandbytes==0.37.2
+      - bitsandbytes==0.38.1
+      - hidet==0.2.3
       - lion-pytorch==0.0.7
       - peft==0.2.0
       - python-slugify>=6.1.2
-      - safetensors==0.3.0
+      - safetensors==0.3.1
       - setuptools==65.6.3
       - test-tube>=0.7.5
       - timm==0.8.17.dev0
-      - transformers==4.27.1
+      - transformers==4.28.1
       - triton==2.0.0.post1
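After rebuilding the environment, the bumped pins above (xformers, bitsandbytes, safetensors, transformers, plus the new hidet) can be sanity-checked from Python. A minimal sketch, not part of the repo:

from importlib.metadata import version

# Print the resolved versions of the packages this commit updates or adds.
for pkg in ('bitsandbytes', 'hidet', 'safetensors', 'transformers', 'xformers'):
    print(f'{pkg}: {version(pkg)}')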
diff --git a/train_lora.py b/train_lora.py
index d5aa78d..64346bc 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -5,9 +5,12 @@ import itertools
 from pathlib import Path
 from functools import partial
 import math
+import warnings
 
 import torch
+import torch._dynamo
 import torch.utils.checkpoint
+import hidet
 
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -32,10 +35,17 @@ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
 
 logger = get_logger(__name__)
 
+warnings.filterwarnings('ignore')
+
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
+torch._dynamo.config.log_level = logging.ERROR
+
+hidet.torch.dynamo_config.use_tensor_core(True)
+hidet.torch.dynamo_config.search_space(2)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
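train_lora.py now imports hidet and tunes its dynamo backend: use_tensor_core(True) permits tensor-core kernels and search_space(2) selects hidet's widest (and slowest) kernel search. These settings only take effect once a module is compiled with the 'hidet' backend that the import registers with torch.compile; a bare torch.compile call uses the default inductor backend. A minimal standalone sketch of that usage, assuming a CUDA device and the pinned hidet 0.2.3 (the model and shapes are illustrative, not from the repo):

import torch
import hidet  # importing hidet registers the 'hidet' torch.compile backend

hidet.torch.dynamo_config.use_tensor_core(True)  # allow tensor-core kernels
hidet.torch.dynamo_config.search_space(2)        # widest, slowest kernel search

model = torch.nn.Linear(64, 64).half().cuda()
compiled = torch.compile(model, backend='hidet')

x = torch.randn(8, 64, dtype=torch.float16, device='cuda')
y = compiled(x)  # first call triggers hidet's kernel tuning; later calls reuse it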
diff --git a/training/functional.py b/training/functional.py
index 6ae35a0..e7cc20f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -624,24 +624,24 @@ def train_loop(
             accelerator.log(logs, step=global_step)
 
             if accelerator.is_main_process:
-                if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
+                if avg_acc_val.max > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc_val = avg_acc_val.avg
+                    best_acc_val = avg_acc_val.max
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg > best_acc and milestone_checkpoints:
+                if avg_acc.max > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc = avg_acc.avg
+                    best_acc = avg_acc.max
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -699,6 +699,10 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
+    unet = torch.compile(unet)
+    text_encoder = torch.compile(text_encoder)
+    vae = torch.compile(vae)
+
     callbacks = strategy.callbacks(
         accelerator=accelerator,
         unet=unet,
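The milestone condition in train_loop now compares the meter's running maximum rather than its running average before writing a checkpoint. The meter class itself is outside this diff; a minimal stand-in (the avg/max attribute names follow the diff, the rest is assumed) showing the two statistics involved:

class AverageMeter:
    # Tracks a running average and the largest single value seen.
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0
        self.max = float('-inf')

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
        self.max = max(self.max, value)

With .max, a milestone checkpoint fires as soon as any single accuracy reading beats the previous record, instead of waiting for the running mean to catch up.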
