path: root/train_ti.py
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  24
1 file changed, 24 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
index 26f7941..6fd974e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -5,13 +5,16 @@ from functools import partial
 from pathlib import Path
 from typing import Union
 import math
+import warnings
 
 import torch
 import torch.utils.checkpoint
+import hidet
 
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
+from diffusers.models.attention_processor import AttnProcessor
 from timm.models import create_model
 import transformers
 
@@ -28,10 +31,18 @@ from training.util import AverageMeter, save_args
 
 logger = get_logger(__name__)
 
+warnings.filterwarnings('ignore')
+
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
+# torch._dynamo.config.log_level = logging.WARNING
+
+hidet.torch.dynamo_config.use_tensor_core(True)
+hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
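
Note on the hidet settings added above: the dynamo_config calls only set compiler options (tensor cores, hidet's fused attention, the smallest kernel search space); the UNet itself is compiled elsewhere in the script when args.compile_unet is set. Below is a minimal sketch of how the hidet backend is typically selected through torch.compile; the module and tensor are illustrative placeholders, not code from this diff:

    import torch
    import hidet

    # Same compiler options as in the hunk above.
    hidet.torch.dynamo_config.use_tensor_core(True)
    hidet.torch.dynamo_config.use_attention(True)
    hidet.torch.dynamo_config.search_space(0)  # 0 = smallest search space, fastest compilation

    # hidet registers itself as a TorchDynamo backend, so compilation is
    # requested by name via torch.compile (requires a CUDA device).
    model = torch.nn.Linear(64, 64).cuda().half()
    compiled = torch.compile(model, backend='hidet')
    out = compiled(torch.randn(8, 64, device='cuda', dtype=torch.float16))
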
@@ -706,6 +717,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    elif args.compile_unet:
+        unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+
+        proc = AttnProcessor()
+
+        def fn_recursive_set_proc(module: torch.nn.Module):
+            if hasattr(module, "processor"):
+                module.processor = proc
+
+            for child in module.children():
+                fn_recursive_set_proc(child)
+
+        fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()