summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-28 23:51:40 +0200
committerVolpeon <git@volpeon.ink>2023-04-28 23:51:40 +0200
commit8e0e47217b7e18288eaa9462c6bbecf7387f3d89 (patch)
tree35f39eb57b55f0be6752e70110541e8c96351963 /train_lora.py
parentFixed loss/acc logging (diff)
downloadtextual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.gz
textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.tar.bz2
textual-inversion-diff-8e0e47217b7e18288eaa9462c6bbecf7387f3d89.zip
Support torch.compile
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py
index d5aa78d..64346bc 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -5,9 +5,12 @@ import itertools
5from pathlib import Path 5from pathlib import Path
6from functools import partial 6from functools import partial
7import math 7import math
8import warnings
8 9
9import torch 10import torch
11import torch._dynamo
10import torch.utils.checkpoint 12import torch.utils.checkpoint
13import hidet
11 14
12from accelerate import Accelerator 15from accelerate import Accelerator
13from accelerate.logging import get_logger 16from accelerate.logging import get_logger
@@ -32,10 +35,17 @@ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
32 35
33logger = get_logger(__name__) 36logger = get_logger(__name__)
34 37
38warnings.filterwarnings('ignore')
39
35 40
36torch.backends.cuda.matmul.allow_tf32 = True 41torch.backends.cuda.matmul.allow_tf32 = True
37torch.backends.cudnn.benchmark = True 42torch.backends.cudnn.benchmark = True
38 43
44torch._dynamo.config.log_level = logging.ERROR
45
46hidet.torch.dynamo_config.use_tensor_core(True)
47hidet.torch.dynamo_config.search_space(2)
48
39 49
40def parse_args(): 50def parse_args():
41 parser = argparse.ArgumentParser( 51 parser = argparse.ArgumentParser(