diff options
author | Volpeon <git@volpeon.ink> | 2023-04-28 23:51:40 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-28 23:51:40 +0200 |
commit | 8e0e47217b7e18288eaa9462c6bbecf7387f3d89 (patch) | |
tree | 35f39eb57b55f0be6752e70110541e8c96351963 /train_lora.py | |
parent | Fixed loss/acc logging (diff) | |
download | textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.gz textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.bz2 textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.zip |
Support torch.compile
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py index d5aa78d..64346bc 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -5,9 +5,12 @@ import itertools | |||
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from functools import partial | 6 | from functools import partial |
7 | import math | 7 | import math |
8 | import warnings | ||
8 | 9 | ||
9 | import torch | 10 | import torch |
11 | import torch._dynamo | ||
10 | import torch.utils.checkpoint | 12 | import torch.utils.checkpoint |
13 | import hidet | ||
11 | 14 | ||
12 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
13 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
@@ -32,10 +35,17 @@ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | |||
32 | 35 | ||
33 | logger = get_logger(__name__) | 36 | logger = get_logger(__name__) |
34 | 37 | ||
38 | warnings.filterwarnings('ignore') | ||
39 | |||
35 | 40 | ||
36 | torch.backends.cuda.matmul.allow_tf32 = True | 41 | torch.backends.cuda.matmul.allow_tf32 = True |
37 | torch.backends.cudnn.benchmark = True | 42 | torch.backends.cudnn.benchmark = True |
38 | 43 | ||
44 | torch._dynamo.config.log_level = logging.ERROR | ||
45 | |||
46 | hidet.torch.dynamo_config.use_tensor_core(True) | ||
47 | hidet.torch.dynamo_config.search_space(2) | ||
48 | |||
39 | 49 | ||
40 | def parse_args(): | 50 | def parse_args(): |
41 | parser = argparse.ArgumentParser( | 51 | parser = argparse.ArgumentParser( |