diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index 26f7941..6fd974e 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -5,13 +5,16 @@ from functools import partial | |||
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | from typing import Union | 6 | from typing import Union |
| 7 | import math | 7 | import math |
| 8 | import warnings | ||
| 8 | 9 | ||
| 9 | import torch | 10 | import torch |
| 10 | import torch.utils.checkpoint | 11 | import torch.utils.checkpoint |
| 12 | import hidet | ||
| 11 | 13 | ||
| 12 | from accelerate import Accelerator | 14 | from accelerate import Accelerator |
| 13 | from accelerate.logging import get_logger | 15 | from accelerate.logging import get_logger |
| 14 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
| 17 | from diffusers.models.attention_processor import AttnProcessor | ||
| 15 | from timm.models import create_model | 18 | from timm.models import create_model |
| 16 | import transformers | 19 | import transformers |
| 17 | 20 | ||
| @@ -28,10 +31,18 @@ from training.util import AverageMeter, save_args | |||
| 28 | 31 | ||
| 29 | logger = get_logger(__name__) | 32 | logger = get_logger(__name__) |
| 30 | 33 | ||
| 34 | warnings.filterwarnings('ignore') | ||
| 35 | |||
| 31 | 36 | ||
| 32 | torch.backends.cuda.matmul.allow_tf32 = True | 37 | torch.backends.cuda.matmul.allow_tf32 = True |
| 33 | torch.backends.cudnn.benchmark = True | 38 | torch.backends.cudnn.benchmark = True |
| 34 | 39 | ||
| 40 | # torch._dynamo.config.log_level = logging.WARNING | ||
| 41 | |||
| 42 | hidet.torch.dynamo_config.use_tensor_core(True) | ||
| 43 | hidet.torch.dynamo_config.use_attention(True) | ||
| 44 | hidet.torch.dynamo_config.search_space(0) | ||
| 45 | |||
| 35 | 46 | ||
| 36 | def parse_args(): | 47 | def parse_args(): |
| 37 | parser = argparse.ArgumentParser( | 48 | parser = argparse.ArgumentParser( |
| @@ -706,6 +717,19 @@ def main(): | |||
| 706 | if args.use_xformers: | 717 | if args.use_xformers: |
| 707 | vae.set_use_memory_efficient_attention_xformers(True) | 718 | vae.set_use_memory_efficient_attention_xformers(True) |
| 708 | unet.enable_xformers_memory_efficient_attention() | 719 | unet.enable_xformers_memory_efficient_attention() |
| 720 | elif args.compile_unet: | ||
| 721 | unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False | ||
| 722 | |||
| 723 | proc = AttnProcessor() | ||
| 724 | |||
| 725 | def fn_recursive_set_proc(module: torch.nn.Module): | ||
| 726 | if hasattr(module, "processor"): | ||
| 727 | module.processor = proc | ||
| 728 | |||
| 729 | for child in module.children(): | ||
| 730 | fn_recursive_set_proc(child) | ||
| 731 | |||
| 732 | fn_recursive_set_proc(unet) | ||
| 709 | 733 | ||
| 710 | if args.gradient_checkpointing: | 734 | if args.gradient_checkpointing: |
| 711 | unet.enable_gradient_checkpointing() | 735 | unet.enable_gradient_checkpointing() |
