diff options
author | Volpeon <git@volpeon.ink> | 2023-04-29 10:47:08 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-29 10:47:08 +0200 |
commit | 5188d1e4c1d6aed3c510e5a534e39e40d8a3365d (patch) | |
tree | 9d05b8912aea05fad5c1bdc3170f38cd0ce1ed5e | |
parent | Fixed model compilation (diff) | |
download | textual-inversion-diff-5188d1e4c1d6aed3c510e5a534e39e40d8a3365d.tar.gz textual-inversion-diff-5188d1e4c1d6aed3c510e5a534e39e40d8a3365d.tar.bz2 textual-inversion-diff-5188d1e4c1d6aed3c510e5a534e39e40d8a3365d.zip |
Optional xformers
-rw-r--r-- | train_lora.py | 13 |
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/train_lora.py b/train_lora.py index 74afeed..d95dbb9 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -44,7 +44,7 @@ torch.backends.cudnn.benchmark = True | |||
44 | torch._dynamo.config.log_level = logging.ERROR | 44 | torch._dynamo.config.log_level = logging.ERROR |
45 | 45 | ||
46 | hidet.torch.dynamo_config.use_tensor_core(True) | 46 | hidet.torch.dynamo_config.use_tensor_core(True) |
47 | hidet.torch.dynamo_config.search_space(2) | 47 | hidet.torch.dynamo_config.search_space(1) |
48 | 48 | ||
49 | 49 | ||
50 | def parse_args(): | 50 | def parse_args(): |
@@ -461,6 +461,11 @@ def parse_args(): | |||
461 | help="Compile UNet with Torch Dynamo.", | 461 | help="Compile UNet with Torch Dynamo.", |
462 | ) | 462 | ) |
463 | parser.add_argument( | 463 | parser.add_argument( |
464 | "--use_xformers", | ||
465 | action="store_true", | ||
466 | help="Use xformers.", | ||
467 | ) | ||
468 | parser.add_argument( | ||
464 | "--lora_rank", | 469 | "--lora_rank", |
465 | type=int, | 470 | type=int, |
466 | default=256, | 471 | default=256, |
@@ -715,8 +720,10 @@ def main(): | |||
715 | text_encoder = LoraModel(text_encoder_config, text_encoder) | 720 | text_encoder = LoraModel(text_encoder_config, text_encoder) |
716 | 721 | ||
717 | vae.enable_slicing() | 722 | vae.enable_slicing() |
718 | vae.set_use_memory_efficient_attention_xformers(True) | 723 | |
719 | unet.enable_xformers_memory_efficient_attention() | 724 | if args.use_xformers: |
725 | vae.set_use_memory_efficient_attention_xformers(True) | ||
726 | unet.enable_xformers_memory_efficient_attention() | ||
720 | 727 | ||
721 | if args.gradient_checkpointing: | 728 | if args.gradient_checkpointing: |
722 | unet.enable_gradient_checkpointing() | 729 | unet.enable_gradient_checkpointing() |