diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index a58bef7..12d7e72 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -49,7 +49,7 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
| 49 | torch.backends.cudnn.benchmark = True | 49 | torch.backends.cudnn.benchmark = True |
| 50 | 50 | ||
| 51 | torch._dynamo.config.log_level = logging.WARNING | 51 | torch._dynamo.config.log_level = logging.WARNING |
| 52 | torch._dynamo.config.suppress_errors = True | 52 | # torch._dynamo.config.suppress_errors = True |
| 53 | 53 | ||
| 54 | hidet.torch.dynamo_config.use_tensor_core(True) | 54 | hidet.torch.dynamo_config.use_tensor_core(True) |
| 55 | hidet.torch.dynamo_config.search_space(0) | 55 | hidet.torch.dynamo_config.search_space(0) |
| @@ -992,6 +992,7 @@ def main(): | |||
| 992 | VlpnDataModule, | 992 | VlpnDataModule, |
| 993 | data_file=args.train_data_file, | 993 | data_file=args.train_data_file, |
| 994 | tokenizer=tokenizer, | 994 | tokenizer=tokenizer, |
| 995 | constant_prompt_length=args.compile_unet, | ||
| 995 | class_subdir=args.class_image_dir, | 996 | class_subdir=args.class_image_dir, |
| 996 | with_guidance=args.guidance_scale != 0, | 997 | with_guidance=args.guidance_scale != 0, |
| 997 | num_class_images=args.num_class_images, | 998 | num_class_images=args.num_class_images, |
