summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py
index a58bef7..12d7e72 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -49,7 +49,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
49torch.backends.cudnn.benchmark = True 49torch.backends.cudnn.benchmark = True
50 50
51torch._dynamo.config.log_level = logging.WARNING 51torch._dynamo.config.log_level = logging.WARNING
52torch._dynamo.config.suppress_errors = True 52# torch._dynamo.config.suppress_errors = True
53 53
54hidet.torch.dynamo_config.use_tensor_core(True) 54hidet.torch.dynamo_config.use_tensor_core(True)
55hidet.torch.dynamo_config.search_space(0) 55hidet.torch.dynamo_config.search_space(0)
@@ -992,6 +992,7 @@ def main():
992 VlpnDataModule, 992 VlpnDataModule,
993 data_file=args.train_data_file, 993 data_file=args.train_data_file,
994 tokenizer=tokenizer, 994 tokenizer=tokenizer,
995 constant_prompt_length=args.compile_unet,
995 class_subdir=args.class_image_dir, 996 class_subdir=args.class_image_dir,
996 with_guidance=args.guidance_scale != 0, 997 with_guidance=args.guidance_scale != 0,
997 num_class_images=args.num_class_images, 998 num_class_images=args.num_class_images,