From 8e0e47217b7e18288eaa9462c6bbecf7387f3d89 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 28 Apr 2023 23:51:40 +0200 Subject: Support torch.compile --- train_lora.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'train_lora.py') diff --git a/train_lora.py b/train_lora.py index d5aa78d..64346bc 100644 --- a/train_lora.py +++ b/train_lora.py @@ -5,9 +5,12 @@ import itertools from pathlib import Path from functools import partial import math +import warnings import torch +import torch._dynamo import torch.utils.checkpoint +import hidet from accelerate import Accelerator from accelerate.logging import get_logger @@ -32,10 +35,17 @@ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] logger = get_logger(__name__) +warnings.filterwarnings('ignore') + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True +torch._dynamo.config.log_level = logging.ERROR + +hidet.torch.dynamo_config.use_tensor_core(True) +hidet.torch.dynamo_config.search_space(2) + def parse_args(): parser = argparse.ArgumentParser( -- cgit v1.2.3-54-g00ecf