diff options
author | Volpeon <git@volpeon.ink> | 2023-05-16 12:59:08 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-16 12:59:08 +0200 |
commit | 1aace3e44dae0489130039714f67d980628c92ec (patch) | |
tree | 59a972b64bb3a3253e310055fc24381db68e8608 /train_lora.py | |
parent | Patch xformers to cast dtypes (diff) | |
download | textual-inversion-diff-1aace3e44dae0489130039714f67d980628c92ec.tar.gz textual-inversion-diff-1aace3e44dae0489130039714f67d980628c92ec.tar.bz2 textual-inversion-diff-1aace3e44dae0489130039714f67d980628c92ec.zip |
Avoid model recompilation due to varying prompt lengths
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index a58bef7..12d7e72 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -49,7 +49,7 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
49 | torch.backends.cudnn.benchmark = True | 49 | torch.backends.cudnn.benchmark = True |
50 | 50 | ||
51 | torch._dynamo.config.log_level = logging.WARNING | 51 | torch._dynamo.config.log_level = logging.WARNING |
52 | torch._dynamo.config.suppress_errors = True | 52 | # torch._dynamo.config.suppress_errors = True |
53 | 53 | ||
54 | hidet.torch.dynamo_config.use_tensor_core(True) | 54 | hidet.torch.dynamo_config.use_tensor_core(True) |
55 | hidet.torch.dynamo_config.search_space(0) | 55 | hidet.torch.dynamo_config.search_space(0) |
@@ -992,6 +992,7 @@ def main(): | |||
992 | VlpnDataModule, | 992 | VlpnDataModule, |
993 | data_file=args.train_data_file, | 993 | data_file=args.train_data_file, |
994 | tokenizer=tokenizer, | 994 | tokenizer=tokenizer, |
995 | constant_prompt_length=args.compile_unet, | ||
995 | class_subdir=args.class_image_dir, | 996 | class_subdir=args.class_image_dir, |
996 | with_guidance=args.guidance_scale != 0, | 997 | with_guidance=args.guidance_scale != 0, |
997 | num_class_images=args.num_class_images, | 998 | num_class_images=args.num_class_images, |